diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py deleted file mode 100755 index e07cee86e..000000000 --- a/egs/commonvoice/ASR/local/compile_hlg.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_n_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lm", - type=str, - default="G_3_gram", - help="""Stem name for LM used in HLG compiling. - """, - ) - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - lm: - The language stem base name. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): - logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"{lang_dir}/lm/{lm}.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info(f"Loading {lm}.fst.txt") - with open(f"{lang_dir}/lm/{lm}.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir, args.lm) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py deleted file mode 100755 index b871c1cf2..000000000 --- a/egs/commonvoice/ASR/local/compile_lg.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates LG from - - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_3_gram.fst.txt - -The generated LG is saved in $lang_dir/LG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_LG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing LG. - """ - lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path(f"{lang_dir}/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load(f"{lang_dir}/lm/G_3_gram.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info("Loading G_3_gram.fst.txt") - with open(f"{lang_dir}/lm/G_3_gram.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), f"{lang_dir}/lm/G_3_gram.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - # assert isinstance(LG.aux_labels, k2.RaggedTensor) - # LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - if isinstance(LG.aux_labels, k2.RaggedTensor): - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - else: - LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - return LG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "LG.pt").is_file(): - logging.info(f"{lang_dir}/LG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - LG = compile_LG(lang_dir) - logging.info(f"Saving LG.pt to {lang_dir}") - torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") - - -if __name__ == "__main__": - formatter = "%(asctime)s 
%(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/commonvoice/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 100644 index 0d7e86fcf..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. 
a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 - # 1) feature embedding: out_lens=(x_lens-7)//2 - # 2) output subsampling: out_lens=(out_lens+1)//2 - self.pad_length = 7 - - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - tail_pad_len: int = 0, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length + tail_pad_len), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py deleted file mode 100755 index 1f870ca5a..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ /dev/null @@ -1,367 +0,0 @@ -#!/usr/bin/env python3 - -""" -Please see -https://k2-fsa.github.io/icefall/model-export/export-ncnn.html -for more details about how to use this file. - -We use -https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed -to demonstrate the usage of this file. - -1. 
Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char_bpe/L.pt" -git lfs pull --include "data/lang_char_bpe/L_disambig.pt" -git lfs pull --include "data/lang_char_bpe/Linv.pt" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export to ncnn - -./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ - --exp-dir $repo/exp \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --num-encoder-layers "2,4,3,2,4" \ - --feedforward-dims "1024,1024,1536,1536,1024" \ - --nhead "8,8,8,8,8" \ - --encoder-dims "384,384,384,384,384" \ - --attention-dims "192,192,192,192,192" \ - --encoder-unmasked-dims "256,256,256,256,256" \ - --zipformer-downsampling-factors "1,2,4,8,2" \ - --cnn-module-kernels "31,31,31,31,31" \ - --decoder-dim 512 \ - --joiner-dim 512 - -cd $repo/exp - -pnnx encoder_jit_trace-pnnx.pt -pnnx decoder_jit_trace-pnnx.pt -pnnx joiner_jit_trace-pnnx.pt - -You can find converted models at -https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 - -See ./streaming-ncnn-decode.py -and -https://github.com/k2-fsa/sherpa-ncnn -for usage. -""" - -import argparse -import logging -from pathlib import Path - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
- "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - - decode_chunk_len = encoder_model.decode_chunk_size * 2 - pad_length = 7 - T = decode_chunk_len + pad_length # 32 + 7 = 39 - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - - x = torch.zeros(1, T, 80, dtype=torch.float32) - states = encoder_model.get_init_state() - - traced_model = torch.jit.trace(encoder_model, (x, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. 
- - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) - - encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) - 
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - logging.info("Using torch.jit.trace()") - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 120000 index 000000000..72e43c297 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py deleted file mode 100755 index 0f84eca83..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python3 - -""" -Please see -https://k2-fsa.github.io/icefall/model-export/export-ncnn.html -for more details about how to use this file. - -We use -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -to demonstrate the usage of this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export to ncnn - -./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --exp-dir $repo/exp \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - \ - --decode-chunk-len 32 \ - --num-encoder-layers "2,4,3,2,4" \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --nhead "8,8,8,8,8" \ - --encoder-dims "384,384,384,384,384" \ - --attention-dims "192,192,192,192,192" \ - --encoder-unmasked-dims "256,256,256,256,256" \ - --zipformer-downsampling-factors "1,2,4,8,2" \ - --cnn-module-kernels "31,31,31,31,31" \ - --decoder-dim 512 \ - --joiner-dim 512 - -cd $repo/exp - -pnnx encoder_jit_trace-pnnx.pt -pnnx decoder_jit_trace-pnnx.pt -pnnx joiner_jit_trace-pnnx.pt - -You can find converted models at -https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 - -See ./streaming-ncnn-decode.py -and -https://github.com/k2-fsa/sherpa-ncnn -for usage. 
-""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - - decode_chunk_len = encoder_model.decode_chunk_size * 2 - pad_length = 7 - T = decode_chunk_len + pad_length # 32 + 7 = 39 - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - - x = torch.zeros(1, T, 80, dtype=torch.float32) - states = encoder_model.get_init_state() - - traced_model = torch.jit.trace(encoder_model, (x, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. 
- """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg 
- start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) - - encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - logging.info("Using torch.jit.trace()") - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 120000 index 000000000..3b36924ef --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py deleted file mode 100755 index dd19a947c..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ /dev/null @@ -1,647 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a transducer model from PyTorch to ONNX. -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -as an example to show how to use this file. -1. Download the pre-trained model -cd egs/librispeech/ASR -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" -cd exp -ln -s pretrained.pt epoch-99.pt -popd -2. 
Export the model to ONNX -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ -It will generate the following 3 files in $repo/exp - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx -See ./onnx_pretrained.py for how to use the exported models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, List, Tuple - -import onnx -import sentencepiece as spm -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import Zipformer - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. 
- """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: - """Please see the help information of Zipformer.streaming_forward""" - N = x.size(0) - T = x.size(1) - x_lens = torch.tensor([T] * N, device=x.device) - - output, _, new_states = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - - output = self.encoder_proj(output) - # Now output is of shape (N, T, joiner_dim) - - return output, new_states - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """ - Onnx model inputs: - - 0: src - - many state tensors (the exact number depending on the actual model) - Onnx model outputs: - - 0: output, its shape is (N, T, joiner_dim) - - many state tensors (the exact number depending on the actual model) - Args: - encoder_model: - The model to be exported - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - - encoder_model.encoder.__class__.forward = ( - encoder_model.encoder.__class__.streaming_forward - ) - - decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 - pad_length = 7 - T = decode_chunk_len + pad_length - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"pad_length: {pad_length}") - logging.info(f"T: {T}") - - x = torch.rand(1, T, 80, dtype=torch.float32) - - init_state = encoder_model.encoder.get_init_state() - - num_encoders = encoder_model.encoder.num_encoders - logging.info(f"num_encoders: {num_encoders}") - logging.info(f"len(init_state): {len(init_state)}") - - inputs = {} - input_names = ["x"] - - outputs = {} - output_names = ["encoder_out"] - - def build_inputs_outputs(tensors, name, N): - for i, s in enumerate(tensors): - logging.info(f"{name}_{i}.shape: {s.shape}") - inputs[f"{name}_{i}"] = {N: "N"} - outputs[f"new_{name}_{i}"] = {N: "N"} - input_names.append(f"{name}_{i}") - output_names.append(f"new_{name}_{i}") - - num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) - encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) - attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) - cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) - ds = encoder_model.encoder.zipformer_downsampling_factors - left_context_len = encoder_model.encoder.left_context_len - left_context_len = [left_context_len // k for k in ds] - left_context_len = ",".join(map(str, left_context_len)) - - meta_data = { - "model_type": "zipformer", - "version": "1", - "model_author": "k2-fsa", - "decode_chunk_len": str(decode_chunk_len), # 32 - "T": str(T), # 39 - "num_encoder_layers": num_encoder_layers, - "encoder_dims": encoder_dims, - "attention_dims": attention_dims, - "cnn_module_kernels": cnn_module_kernels, - "left_context_len": left_context_len, - } - logging.info(f"meta_data: {meta_data}") - - # (num_encoder_layers, 1) - cached_len = init_state[num_encoders * 0 : num_encoders * 1] - - # (num_encoder_layers, 1, encoder_dim) - cached_avg = init_state[num_encoders * 1 : num_encoders * 2] - - # (num_encoder_layers, left_context_len, 1, attention_dim) - cached_key = init_state[num_encoders * 2 : num_encoders * 3] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val = init_state[num_encoders * 3 : num_encoders * 4] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] - - build_inputs_outputs(cached_len, "cached_len", 1) - build_inputs_outputs(cached_avg, "cached_avg", 1) - build_inputs_outputs(cached_key, "cached_key", 2) - build_inputs_outputs(cached_val, "cached_val", 2) - build_inputs_outputs(cached_val2, "cached_val2", 2) - build_inputs_outputs(cached_conv1, "cached_conv1", 1) - build_inputs_outputs(cached_conv2, "cached_conv2", 1) - - logging.info(inputs) - logging.info(outputs) - logging.info(input_names) - logging.info(output_names) - - torch.onnx.export( - encoder_model, - (x, init_state), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "x": {0: "N"}, - "encoder_out": {0: "N"}, - **inputs, - **outputs, - }, - ) - - 
add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - The exported model has one input: - - y: a torch.int64 tensor of shape (N, context_size) - and has one output: - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - Note: The argument need_pad is fixed to False. - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - and produces one output: - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg 
== 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - if params.use_averaged_model: - suffix += "-with-averaged-model" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate 
int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 000000000..57a0cd0a0 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 100755 index 5735ee692..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1,878 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. 
- -To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7_streaming/decode.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp - -(3) Export to ONNX format with pretrained.pt - -cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp -ln -s pretrained.pt epoch-999.pt -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --use-averaged-model False \ - --epoch 999 \ - --avg 1 \ - --fp16 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(4) Export to ONNX format for triton server - -cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp -ln -s pretrained.pt epoch-999.pt -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --use-averaged-model False \ - --epoch 999 \ - --avg 1 \ - --fp16 \ - --onnx-triton 1 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - -Check -https://github.com/k2-fsa/sherpa/tree/master/triton -for how to use the exported models outside of icefall. - -""" - - -import argparse -import logging -from pathlib import Path - -import onnxruntime -import sentencepiece as spm -import torch -import torch.nn as nn -from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--onnx-triton", - type=str2bool, - default=False, - help="""If True, --onnx would export model into the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. - """, - ) - - parser.add_argument( - "--fp16", - action="store_true", - help="whether to export fp16 onnx model, default false", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): - for a, b in zip(xlist, blist): - try: - torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - except AssertionError as error: - if tolerate_small_mismatch: - print("small mismatch detected", error) - else: - return False - return True - - -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - batch_size = 17 - seq_len = 101 - torch.manual_seed(0) - x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) - x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] - states = stack_states(initial_states) - - left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks - encoder_attention_dim = encoder_model.encoders[0].attention_dim - - len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 - avg_cache = torch.cat( - states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] - ).transpose( - 0, 1 - ) # [B,15,384] - cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( - 0, 1 - ) # [B,2*15,384,cnn_kernel-1] - pad_tensors = [ - torch.nn.functional.pad( - tensor, - ( - 0, - encoder_attention_dim - tensor.shape[-1], - 0, - 0, - 0, - left_context_len - tensor.shape[1], - 0, - 0, - ), - ) - for tensor in states[ - 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders - ] - ] - attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] - - encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) - - torch.onnx.export( - encoder_model_wrapper, - (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "x", - "x_lens", - "len_cache", - "avg_cache", - "attn_cache", - "cnn_cache", - ], - output_names=[ - "encoder_out", - "encoder_out_lens", - "new_len_cache", - "new_avg_cache", - "new_attn_cache", - "new_cnn_cache", - ], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - "len_cache": {0: "N"}, - "avg_cache": {0: "N"}, - "attn_cache": {0: "N"}, - "cnn_cache": {0: "N"}, - "new_len_cache": {0: "N"}, - "new_avg_cache": {0: "N"}, - "new_attn_cache": {0: "N"}, - "new_cnn_cache": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - # Test onnx encoder with torch native encoder - encoder_model.eval() - ( - encoder_out_torch, - encoder_out_lens_torch, - new_states_torch, - ) = encoder_model.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - ort_session = onnxruntime.InferenceSession( - str(encoder_filename), providers=["CPUExecutionProvider"] - ) - ort_inputs = { - "x": x.numpy(), - "x_lens": x_lens.numpy(), - "len_cache": len_cache.numpy(), - "avg_cache": avg_cache.numpy(), - "attn_cache": attn_cache.numpy(), - "cnn_cache": cnn_cache.numpy(), - } - ort_outs = ort_session.run(None, ort_inputs) - - assert test_acc( - [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] - ) - logging.info(f"{encoder_filename} acc test succeeded.") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. 
- - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_decoder_model_onnx_triton( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - - decoder_model = TritonOnnxDecoder(decoder_model) - - torch.onnx.export( - decoder_model, - (y,), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - -def export_joiner_model_onnx_triton( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported model has two inputs: - - encoder_out: a tensor of shape (N, encoder_out_dim) - - decoder_out: a tensor of shape (N, decoder_out_dim) - and has one output: - - joiner_out: a tensor of shape (N, vocab_size) - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. 
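A minimal sketch of running the exported Triton joiner with onnxruntime (the file name and the 384/512 dimensions are illustrative; the input and output names follow the export call below). Raw, unprojected encoder and decoder outputs are passed in, because the projections are baked into the exported graph:

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("joiner.onnx", providers=["CPUExecutionProvider"])
encoder_out = np.random.rand(4, 384).astype(np.float32)  # (N, encoder_out_dim), dim assumed
decoder_out = np.random.rand(4, 512).astype(np.float32)  # (N, decoder_out_dim), dim assumed
(logit,) = session.run(
    ["logit"],
    {"encoder_out": encoder_out, "decoder_out": decoder_out},
)
token_ids = logit.argmax(axis=-1)  # greedy pick per utterance, shape (N,)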
- """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - joiner_model = TritonOnnxJoiner(joiner_model) - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (encoder_out, decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out", "decoder_out"], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - 
filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.onnx: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 13 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - if not params.onnx_triton: - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - else: - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx_triton( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx_triton( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - - if params.fp16: - try: - import onnxmltools - from onnxmltools.utils.float16_converter import convert_float_to_float16 - except ImportError: - print("Please install onnxmltools!") - import sys - - sys.exit(1) - - def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): - onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) - onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) - onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) - - encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" - export_onnx_fp16(encoder_filename, encoder_fp16_filename) - - decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" - export_onnx_fp16(decoder_filename, decoder_fp16_filename) - - joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" - export_onnx_fp16(joiner_filename, joiner_fp16_filename) - - if not params.onnx_triton: - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) - encoder_proj_fp16_filename = ( - params.exp_dir / "joiner_encoder_proj_fp16.onnx" - ) - export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) - - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) - decoder_proj_fp16_filename = ( - params.exp_dir / "joiner_decoder_proj_fp16.onnx" - ) - export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) - - elif params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - model.encoder.__class__.forward = model.encoder.__class__.streaming_forward - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. 
Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 000000000..2acafdc61 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py deleted file mode 100755 index 4fd5e1820..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ /dev/null @@ -1,278 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless7_streaming/jit_pretrained.py \ - --nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - model.encoder.decode_chunk_size = args.decode_chunk_len // 2 - - model.eval() - - 
model.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 120000 index 000000000..5d9c6ba00 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py deleted file mode 100755 index a164f3f69..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python3 - -""" -Usage: -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 10 \ - --use-averaged-model=True \ - --decode-chunk-len 32 -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: torch.nn.Module, - encoder_filename: str, - params: AttributeDict, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - decode_chunk_len = params.decode_chunk_len # before subsampling - pad_length = 7 - s = f"decode_chunk_len: {decode_chunk_len}" - logging.info(s) - assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( - encoder_model.decode_chunk_size, - decode_chunk_len, - ) - - T = decode_chunk_len + pad_length - - x = torch.zeros(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - states = encoder_model.get_init_state(device=x.device) - - encoder_model.__class__.forward = encoder_model.__class__.streaming_forward - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: torch.nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: torch.nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. 
- - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename, params) - - 
logging.info("Exporting decoder") - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 120000 index 000000000..457131699 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py deleted file mode 100755 index f2ac1914d..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ /dev/null @@ -1,295 +0,0 @@ -#!/usr/bin/env python3 -# flake8: noqa -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models exported by `torch.jit.trace()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 10 \ - --use-averaged-model=True \ - --decode-chunk-len 32 - -Usage of this script: - -./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --decode-chunk-len 32 \ - /path/to/foo.wav \ -""" - -import argparse -import logging -import math -from typing import List, Optional - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. 
", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -): - assert encoder_out.ndim == 2 - context_size = 2 - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) - # decoder_input.shape (1,, 1 context_size) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - else: - assert decoder_out.ndim == 2 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for i in range(T): - cur_encoder_out = encoder_out[i : i + 1] - joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - - decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - - return hyp, decoder_out - - -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. 
- """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - - logging.info(f"Reading sound files: {args.sound_file}") - wave_samples = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=args.sample_rate, - )[0] - logging.info(wave_samples.shape) - - logging.info("Decoding started") - chunk_length = args.decode_chunk_len - assert encoder.decode_chunk_size == chunk_length // 2, ( - encoder.decode_chunk_size, - chunk_length, - ) - - # we subsample features with ((x_len - 7) // 2 + 1) // 2 - pad_length = 7 - T = chunk_length + pad_length - - logging.info(f"chunk_length: {chunk_length}") - - states = encoder.get_init_state(device) - - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.2 second - num_processed_frames = 0 - - hyp = None - decoder_out = None - - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32) - encoder_out, out_lens, states = encoder( - x=frames, - x_lens=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp - ) - - context_size = 2 - logging.info(args.sound_file) - logging.info(sp.decode(hyp[context_size:])) - - logging.info("Decoding Done") - - -torch.set_num_threads(4) -torch.set_num_interop_threads(1) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 120000 index 000000000..2b8fa3cbb --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py deleted file mode 100755 index d7a4b9551..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ /dev/null @@ -1,260 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script checks that exported ONNX models produce the same output -with the given torchscript model for the same input. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model via torch.jit.trace() - -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp - - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - -3. Export the model to ONNX - -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -4. 
Run this file - -./pruned_transducer_stateless7_streaming/onnx_check.py \ - --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ - --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ - --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel -from zipformer import stack_states - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-encoder-filename", - required=True, - type=str, - help="Path to the torchscript encoder model", - ) - - parser.add_argument( - "--jit-decoder-filename", - required=True, - type=str, - help="Path to the torchscript decoder model", - ) - - parser.add_argument( - "--jit-joiner-filename", - required=True, - type=str, - help="Path to the torchscript joiner model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the ONNX encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the ONNX decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the ONNX joiner model", - ) - - return parser - - -def test_encoder( - torch_encoder_model: torch.jit.ScriptModule, - torch_encoder_proj_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - N = torch.randint(1, 100, size=(1,)).item() - T = onnx_model.segment - C = 80 - x_lens = torch.tensor([T] * N) - torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] - torch_states = stack_states(torch_states) - - onnx_model.init_encoder_states(N) - - for i in range(5): - logging.info(f"test_encoder: iter {i}") - x = torch.rand(N, T, C) - torch_encoder_out, _, torch_states = torch_encoder_model( - x, x_lens, torch_states - ) - torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) - - onnx_encoder_out = onnx_model.run_encoder(x) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_decoder_model: torch.jit.ScriptModule, - torch_decoder_proj_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_joiner_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] - decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - 
encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) - projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) - - torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_encoder_model = torch.jit.load(args.jit_encoder_filename) - torch_decoder_model = torch.jit.load(args.jit_decoder_filename) - torch_joiner_model = torch.jit.load(args.jit_joiner_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - # When exporting the model to onnx, we have already put the encoder_proj - # inside the encoder. - test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) - - logging.info("Test decoder") - # When exporting the model to onnx, we have already put the decoder_proj - # inside the decoder. - test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_joiner_model, onnx_model) - - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20230207) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 000000000..28bf7bb82 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py deleted file mode 100644 index 71a418742..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Optional, Tuple - -import torch - - -class OnnxStreamingEncoder(torch.nn.Module): - """This class warps the streaming Zipformer to reduce the number of - state tensors for onnx. 
- https://github.com/k2-fsa/icefall/pull/831 - """ - - def __init__(self, encoder): - """ - Args: - encoder: An instance of Zipformer Class - """ - super().__init__() - self.model = encoder - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - len_cache: torch.tensor, - avg_cache: torch.tensor, - attn_cache: torch.tensor, - cnn_cache: torch.tensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - len_cache: - The cached numbers of past frames. - avg_cache: - The cached average tensors. - attn_cache: - The cached key tensors of the first attention modules. - The cached value tensors of the first attention modules. - The cached value tensors of the second attention modules. - cnn_cache: - The cached left contexts of the first convolution modules. - The cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 2 tensors: - - """ - num_encoder_layers = [] - encoder_attention_dims = [] - states = [] - for i, encoder in enumerate(self.model.encoders): - num_encoder_layers.append(encoder.num_layers) - encoder_attention_dims.append(encoder.attention_dim) - - len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] - offset = 0 - for num_layer in num_encoder_layers: - states.append(len_cache[offset : offset + num_layer]) - offset += num_layer - - avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] - offset = 0 - for num_layer in num_encoder_layers: - states.append(avg_cache[offset : offset + num_layer]) - offset += num_layer - - attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] - left_context_len = attn_cache.shape[1] - offset = 0 - for i, num_layer in enumerate(num_encoder_layers): - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[offset : offset + num_layer, : left_context_len // ds] - ) - offset += num_layer - for i, num_layer in enumerate(num_encoder_layers): - encoder_attention_dim = encoder_attention_dims[i] - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[ - offset : offset + num_layer, - : left_context_len // ds, - :, - : encoder_attention_dim // 2, - ] - ) - offset += num_layer - for i, num_layer in enumerate(num_encoder_layers): - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[ - offset : offset + num_layer, - : left_context_len // ds, - :, - : encoder_attention_dim // 2, - ] - ) - offset += num_layer - - cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] - offset = 0 - for num_layer in num_encoder_layers: - states.append(cnn_cache[offset : offset + num_layer]) - offset += num_layer - for num_layer in num_encoder_layers: - states.append(cnn_cache[offset : offset + num_layer]) - offset += num_layer - - encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - - new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( - 0, 1 - ) # [B,15] - new_avg_cache = torch.cat( - states[self.model.num_encoders : 2 * self.model.num_encoders] - ).transpose( - 0, 1 - ) # [B,15,384] - new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( - 0, 1 - ) # [B,2*15,384,cnn_kernel-1] - assert len(set(encoder_attention_dims)) == 1 - pad_tensors = [ - 
torch.nn.functional.pad( - tensor, - ( - 0, - encoder_attention_dims[0] - tensor.shape[-1], - 0, - 0, - 0, - left_context_len - tensor.shape[1], - 0, - 0, - ), - ) - for tensor in states[ - 2 * self.model.num_encoders : 5 * self.model.num_encoders - ] - ] - new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] - - return ( - encoder_out, - encoder_out_lens, - new_len_cache, - new_avg_cache, - new_attn_cache, - new_cnn_cache, - ) - - -class TritonOnnxDecoder(torch.nn.Module): - """This class warps the Decoder in decoder.py - to remove the scalar input "need_pad". - Triton currently doesn't support scalar input. - https://github.com/triton-inference-server/server/issues/2333 - """ - - def __init__( - self, - decoder: torch.nn.Module, - ): - """ - Args: - decoder: A instance of Decoder - """ - super().__init__() - self.model = decoder - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - # False to not pad the input. Should be False during inference. - need_pad = False - return self.model(y, need_pad) - - -class TritonOnnxJoiner(torch.nn.Module): - """This class warps the Joiner in joiner.py - to remove the scalar input "project_input". - Triton currently doesn't support scalar input. - https://github.com/triton-inference-server/server/issues/2333 - "project_input" is set to True. - Triton solutions only need export joiner to a single joiner.onnx. - """ - - def __init__( - self, - joiner: torch.nn.Module, - ): - super().__init__() - self.model = joiner - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - # Apply input projections encoder_proj and decoder_proj. - project_input = True - return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 000000000..c8548d459 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py deleted file mode 100755 index 8192e01fd..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ /dev/null @@ -1,512 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) - -""" -This script loads ONNX models exported by ./export-onnx.py -and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -as an example to show how to use this file. - -1. 
Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files in $repo/exp - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file with the exported ONNX models - -./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav - -Note: Even though this script only supports decoding a single file, -the exported ONNX models do support batch processing. -""" - -import argparse -import logging -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - ) - self.init_encoder_states() - - def init_encoder_states(self, batch_size: int = 1): - encoder_meta = self.encoder.get_modelmeta().custom_metadata_map - - model_type = encoder_meta["model_type"] - assert model_type == "zipformer", model_type - - decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - T = int(encoder_meta["T"]) - - num_encoder_layers = encoder_meta["num_encoder_layers"] - encoder_dims = encoder_meta["encoder_dims"] - attention_dims = encoder_meta["attention_dims"] - cnn_module_kernels = encoder_meta["cnn_module_kernels"] - left_context_len = encoder_meta["left_context_len"] - - def to_int_list(s): - return list(map(int, s.split(","))) - - num_encoder_layers = to_int_list(num_encoder_layers) - encoder_dims = to_int_list(encoder_dims) - attention_dims = to_int_list(attention_dims) - cnn_module_kernels = to_int_list(cnn_module_kernels) - left_context_len = to_int_list(left_context_len) - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - logging.info(f"num_encoder_layers: {num_encoder_layers}") - logging.info(f"encoder_dims: {encoder_dims}") - logging.info(f"attention_dims: {attention_dims}") - logging.info(f"cnn_module_kernels: {cnn_module_kernels}") - logging.info(f"left_context_len: {left_context_len}") - - num_encoders = len(num_encoder_layers) - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - N = batch_size - - for i in range(num_encoders): - cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) - cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) - cached_key.append( - torch.zeros( - num_encoder_layers[i], left_context_len[i], N, attention_dims[i] - ) - ) - cached_val.append( - torch.zeros( - num_encoder_layers[i], - left_context_len[i], - N, - attention_dims[i] // 2, - ) - ) - cached_val2.append( - torch.zeros( - num_encoder_layers[i], - left_context_len[i], - N, - attention_dims[i] // 2, - ) - ) - cached_conv1.append( - torch.zeros( - num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 - ) - ) - cached_conv2.append( - torch.zeros( - num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 - ) - ) - - self.cached_len = cached_len - self.cached_avg = cached_avg - self.cached_key = cached_key - self.cached_val = cached_val - self.cached_val2 = cached_val2 - self.cached_conv1 = cached_conv1 - self.cached_conv2 = cached_conv2 - - self.num_encoders = num_encoders - - self.segment = T - self.offset = decode_chunk_len - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = 
int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def _build_encoder_input_output( - self, - x: torch.Tensor, - ) -> Tuple[Dict[str, np.ndarray], List[str]]: - encoder_input = {"x": x.numpy()} - encoder_output = ["encoder_out"] - - def build_states_input(states: List[torch.Tensor], name: str): - for i, s in enumerate(states): - if isinstance(s, torch.Tensor): - encoder_input[f"{name}_{i}"] = s.numpy() - else: - encoder_input[f"{name}_{i}"] = s - - encoder_output.append(f"new_{name}_{i}") - - build_states_input(self.cached_len, "cached_len") - build_states_input(self.cached_avg, "cached_avg") - build_states_input(self.cached_key, "cached_key") - build_states_input(self.cached_val, "cached_val") - build_states_input(self.cached_val2, "cached_val2") - build_states_input(self.cached_conv1, "cached_conv1") - build_states_input(self.cached_conv2, "cached_conv2") - - return encoder_input, encoder_output - - def _update_states(self, states: List[np.ndarray]): - num_encoders = self.num_encoders - - self.cached_len = states[num_encoders * 0 : num_encoders * 1] - self.cached_avg = states[num_encoders * 1 : num_encoders * 2] - self.cached_key = states[num_encoders * 2 : num_encoders * 3] - self.cached_val = states[num_encoders * 3 : num_encoders * 4] - self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] - self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] - self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] - - def run_encoder(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 - """ - encoder_input, encoder_output_names = self._build_encoder_input_output(x) - out = self.encoder.run(encoder_output_names, encoder_input) - - self._update_states(out[1:]) - - return torch.from_numpy(out[0]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
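    Illustrative usage (a minimal sketch; the wav path is hypothetical):

        waves = read_sound_files(["/path/to/foo.wav"], expected_sample_rate=16000)
        # waves[0] is a 1-D float32 tensor holding only the first channel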
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. - Returns: - Return the decoded results so far. - """ - - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - encoder_out = encoder_out.squeeze(0) - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {args.sound_file}") - waves = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=sample_rate, - )[0] - - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) - - num_processed_frames = 0 - segment = model.segment - offset = model.offset - - context_size = model.context_size - hyp = None - decoder_out = None - - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - 
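            # Each iteration feeds a window of `segment` feature frames to the
            # encoder but advances by only `offset` frames, so consecutive
            # windows overlap by (segment - offset) frames.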
encoder_out = model.run_encoder(frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - text = "" - for i in hyp[context_size:]: - text += symbol_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 000000000..ae4d9bb04 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 100755 index fb77fdd42..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1,355 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. 
-You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_streaming/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. 
- Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - 
beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 000000000..9510b8fde --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py deleted file mode 100755 index 5a36b695f..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ /dev/null @@ -1,419 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Usage: - -./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ - --tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \ - --encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \ - ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav - -You can find pretrained models at -- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 -- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 -""" - -import argparse -import logging -from typing import List, Optional, Tuple - -import k2 -import ncnn -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--tokens", - type=str, - help="Path to tokens.txt", - ) - - parser.add_argument( - "--encoder-param-filename", - type=str, - help="Path to encoder.ncnn.param", - ) - - parser.add_argument( - "--encoder-bin-filename", - type=str, - help="Path to encoder.ncnn.bin", - ) - - parser.add_argument( - "--decoder-param-filename", - type=str, - help="Path to decoder.ncnn.param", - ) - - parser.add_argument( - "--decoder-bin-filename", - type=str, - help="Path to decoder.ncnn.bin", - ) - - parser.add_argument( - "--joiner-param-filename", - type=str, - help="Path to joiner.ncnn.param", - ) - - parser.add_argument( - "--joiner-bin-filename", - type=str, - help="Path to joiner.ncnn.bin", - ) - - parser.add_argument( - "sound_filename", - type=str, - help="Path to foo.wav", - ) - - return parser.parse_args() - - -def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -class Model: - def __init__(self, args): - self.init_encoder(args) - self.init_decoder(args) - self.init_joiner(args) - - # Please change the parameters according to your model - self.num_encoder_layers = to_int_tuple("2,4,3,2,4") - self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model - self.attention_dims = to_int_tuple("192,192,192,192,192") - self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") - self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") - - self.decode_chunk_size = 32 // 2 - num_left_chunks = 4 - self.left_context_length = self.decode_chunk_size * num_left_chunks # 64 - - self.chunk_length = self.decode_chunk_size * 2 - pad_length = 7 - self.T = self.chunk_length + pad_length - - def get_init_states(self) -> List[torch.Tensor]: - cached_len_list = [] - cached_avg_list = [] - cached_key_list = [] - cached_val_list = [] - cached_val2_list = [] - cached_conv1_list = [] - cached_conv2_list = [] - - for i in range(len(self.num_encoder_layers)): - num_layers = self.num_encoder_layers[i] - ds = self.zipformer_downsampling_factors[i] - attention_dim = self.attention_dims[i] - left_context_length = self.left_context_length // ds - encoder_dim = self.encoder_dims[i] - 
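            # For every encoder stack we allocate seven kinds of zero-valued
            # caches: cached_len, cached_avg, cached_key, cached_val,
            # cached_val2, cached_conv1 and cached_conv2.  Their shapes depend
            # on num_layers, the downsampled left context, the attention and
            # encoder dims, and the convolution kernel size.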
cnn_module_kernel = self.cnn_module_kernels[i] - - cached_len_list.append(torch.zeros(num_layers)) - cached_avg_list.append(torch.zeros(num_layers, encoder_dim)) - cached_key_list.append( - torch.zeros(num_layers, left_context_length, attention_dim) - ) - cached_val_list.append( - torch.zeros(num_layers, left_context_length, attention_dim // 2) - ) - cached_val2_list.append( - torch.zeros(num_layers, left_context_length, attention_dim // 2) - ) - cached_conv1_list.append( - torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) - ) - cached_conv2_list.append( - torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) - ) - - states = ( - cached_len_list - + cached_avg_list - + cached_key_list - + cached_val_list - + cached_val2_list - + cached_conv1_list - + cached_conv2_list - ) - - return states - - def init_encoder(self, args): - encoder_net = ncnn.Net() - encoder_net.opt.use_packing_layout = False - encoder_net.opt.use_fp16_storage = False - encoder_net.opt.num_threads = 4 - - encoder_param = args.encoder_param_filename - encoder_model = args.encoder_bin_filename - - encoder_net.load_param(encoder_param) - encoder_net.load_model(encoder_model) - - self.encoder_net = encoder_net - - def init_decoder(self, args): - decoder_param = args.decoder_param_filename - decoder_model = args.decoder_bin_filename - - decoder_net = ncnn.Net() - decoder_net.opt.num_threads = 4 - - decoder_net.load_param(decoder_param) - decoder_net.load_model(decoder_model) - - self.decoder_net = decoder_net - - def init_joiner(self, args): - joiner_param = args.joiner_param_filename - joiner_model = args.joiner_bin_filename - joiner_net = ncnn.Net() - joiner_net.opt.num_threads = 4 - - joiner_net.load_param(joiner_param) - joiner_net.load_model(joiner_model) - - self.joiner_net = joiner_net - - def run_encoder( - self, - x: torch.Tensor, - states: List[torch.Tensor], - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: - A tensor of shape (T, C) - states: - A list of tensors. len(states) == self.num_layers * 4 - Returns: - Return a tuple containing: - - encoder_out, a tensor of shape (T, encoder_dim). 
- - next_states, a list of tensors containing the next states - """ - with self.encoder_net.create_extractor() as ex: - ex.input("in0", ncnn.Mat(x.numpy()).clone()) - - for i in range(len(states)): - name = f"in{i+1}" - ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone()) - - ret, ncnn_out0 = ex.extract("out0") - assert ret == 0, ret - encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - - out_states: List[torch.Tensor] = [] - for i in range(len(states)): - name = f"out{i+1}" - ret, ncnn_out_state = ex.extract(name) - assert ret == 0, ret - ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) - - if i < len(self.num_encoder_layers): - # for cached_len, we need to discard the last dim - ncnn_out_state = ncnn_out_state.squeeze(1) - - out_states.append(ncnn_out_state) - - return encoder_out, out_states - - def run_decoder(self, decoder_input): - assert decoder_input.dtype == torch.int32 - - with self.decoder_net.create_extractor() as ex: - ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) - ret, ncnn_out0 = ex.extract("out0") - assert ret == 0, ret - decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - return decoder_out - - def run_joiner(self, encoder_out, decoder_out): - with self.joiner_net.create_extractor() as ex: - ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) - ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) - ret, ncnn_out0 = ex.extract("out0") - assert ret == 0, ret - joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() - return joiner_out - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. 
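    Illustrative usage (a minimal sketch; ``samples`` stands for a 1-D float32
    tensor of 16 kHz audio, mirroring the decoding loop below):

        fbank = create_streaming_feature_extractor()
        fbank.accept_waveform(sampling_rate=16000, waveform=samples)
        if fbank.num_frames_ready > 0:
            frame0 = fbank.get_frame(0)  # one 80-dim fbank frame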
- """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -def greedy_search( - model: Model, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -): - context_size = 2 - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) - decoder_out = model.run_decoder(decoder_input).squeeze(0) - else: - assert decoder_out.ndim == 1 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t] - - joiner_out = model.run_joiner(cur_encoder_out, decoder_out) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor(decoder_input, dtype=torch.int32) - decoder_out = model.run_decoder(decoder_input).squeeze(0) - - return hyp, decoder_out - - -def main(): - args = get_args() - logging.info(vars(args)) - - model = Model(args) - - sound_file = args.sound_filename - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {sound_file}") - wave_samples = read_sound_files( - filenames=[sound_file], - expected_sample_rate=sample_rate, - )[0] - logging.info(wave_samples.shape) - - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - states = model.get_init_states() - logging.info(f"number of states: {len(states)}") - - hyp = None - decoder_out = None - - num_processed_frames = 0 - segment = model.T - offset = model.chunk_length - - chunk = int(1 * sample_rate) # 0.2 second - - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - encoder_out, states = model.run_encoder(frames, states) - hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - context_size = 2 - text = "" - for i in hyp[context_size:]: - text += symbol_table[i] - text = text.replace("▁", " ").strip() - - logging.info(sound_file) - logging.info(text) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 100644 index a5c422959..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1,2891 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import itertools -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - Identity, - MaxEig, - ScaledConv1d, - Whiten, - _diag, - penalize_abs_values_gt, - random_clamp, - softmax, -) -from torch import Tensor, nn - -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Note: - It is the inverse of :func:`unstack_states`. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - ``states[i][0:num_encoders]`` is the cached numbers of past frames. - ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A new state corresponding to a batch of utterances. - See the input argument of :func:`unstack_states` for the meaning - of the returned tensor. 
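    Illustrative example (a minimal sketch; ``state_list`` is laid out as
    described above):

        batch_states = stack_states(state_list)   # 7 * num_encoders tensors
        restored = unstack_states(batch_states)   # the inverse operation
        assert len(restored) == len(state_list)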
- """ - batch_size = len(state_list) - assert len(state_list[0]) % 7 == 0, len(state_list[0]) - num_encoders = len(state_list[0]) // 7 - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - # For cached_len - len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] - for i in range(num_encoders): - # len_avg: (num_layers, batch_size) - len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) - cached_len.append(len_avg) - - # For cached_avg - avg_list = [ - state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # avg: (num_layers, batch_size, D) - avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) - cached_avg.append(avg) - - # For cached_key - key_list = [ - state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # key: (num_layers, left_context_size, batch_size, D) - key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) - cached_key.append(key) - - # For cached_val - val_list = [ - state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val: (num_layers, left_context_size, batch_size, D) - val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) - cached_val.append(val) - - # For cached_val2 - val2_list = [ - state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val2: (num_layers, left_context_size, batch_size, D) - val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) - cached_val2.append(val2) - - # For cached_conv1 - conv1_list = [ - state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv1: (num_layers, batch_size, D, kernel-1) - conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) - cached_conv1.append(conv1) - - # For cached_conv2 - conv2_list = [ - state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv2: (num_layers, batch_size, D, kernel-1) - conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A list of states. 
- ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - """ - assert len(states) % 7 == 0, len(states) - num_encoders = len(states) // 7 - ( - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) - - batch_size = cached_len[0].shape[1] - - len_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_len[i]: (num_layers, batch_size) - len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - len_list[n].append(len_avg[n]) - - avg_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_avg[i]: (num_layers, batch_size, D) - avg = cached_avg[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - avg_list[n].append(avg[n]) - - key_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_key[i]: (num_layers, left_context, batch_size, D) - key = cached_key[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - key_list[n].append(key[n]) - - val_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val[i]: (num_layers, left_context, batch_size, D) - val = cached_val[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val_list[n].append(val[n]) - - val2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val2[i]: (num_layers, left_context, batch_size, D) - val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val2_list[n].append(val2[n]) - - conv1_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) - conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv1_list[n].append(conv1[n]) - - conv2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) - conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv2_list[n].append(conv2[n]) - - state_list = [ - ( - len_list[i] - + avg_list[i] - + key_list[i] - + val_list[i] - + val2_list[i] - + conv1_list[i] - + conv2_list[i] - ) - for i in range(batch_size) - ] - return state_list - - -class Zipformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimension in 2 encoder stacks - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - cnn_module_kernels (int): Kernel size of convolution module - warmup_batches (float): number of batches to warm up over - """ - - def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - num_left_chunks: int = 4, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 50, - decode_chunk_size: int = 16, - warmup_batches: float = 4000.0, - ) -> 
None: - super(Zipformer, self).__init__() - - self.num_features = num_features - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor - - self.num_left_chunks = num_left_chunks - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - - # Used in decoding - self.decode_chunk_size = decode_chunk_size - - self.left_context_len = self.decode_chunk_size * self.num_left_chunks - - # will be written to, see set_batch_count() - self.batch_count = 0 - self.warmup_end = warmup_batches - - for u, d in zip(encoder_unmasked_dims, encoder_dims): - assert u <= d, (u, d) - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7)//2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7)//2 - # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) - - # each one will be ZipformerEncoder or DownsampledZipformerEncoder - encoders = [] - - self.num_encoder_layers = num_encoder_layers - self.num_encoders = len(encoder_dims) - self.attention_dims = attention_dim - self.cnn_module_kernels = cnn_module_kernels - for i in range(self.num_encoders): - encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], - feedforward_dim[i], - dropout, - cnn_module_kernels[i], - pos_dim, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = ZipformerEncoder( - encoder_layer, - num_encoder_layers[i], - dropout, - warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), - ) - - if zipformer_downsampling_factors[i] != 1: - encoder = DownsampledZipformerEncoder( - encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) - - # initializes self.skip_layers and self.skip_modules - self._init_skip_modules() - - self.downsample_output = AttentionDownsample( - encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor - ) - - def _get_layer_skip_dropout_prob(self): - if not self.training: - return 0.0 - batch_count = self.batch_count - min_dropout_prob = 0.025 - - if batch_count > self.warmup_end: - return min_dropout_prob - else: - return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) - - def _init_skip_modules(self): - """ - If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, - we combine the outputs of layers 1 and 4. - """ - skip_layers = [] - skip_modules = [] - z = self.zipformer_downsampling_factors - for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: - skip_layers.append(None) - skip_modules.append(SimpleCombinerIdentity()) - else: - # TEMP - for j in range(i - 2, -1, -1): - if z[j] <= z[i] or j == 0: - # TEMP logging statement. 
- logging.info( - f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." - ) - skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) - break - self.skip_layers = skip_layers - self.skip_modules = nn.ModuleList(skip_modules) - - def get_feature_masks(self, x: torch.Tensor) -> List[float]: - # Note: The actual return type is Union[List[float], List[Tensor]], - # but to make torch.jit.script() work, we use List[float] - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_downsampling_factors times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (num_frames, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) - - max_downsampling_factor = max(self.zipformer_downsampling_factors) - - num_frames_max = num_frames0 + max_downsampling_factor - 1 - - feature_mask_dropout_prob = 0.15 - - # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) - - feature_masks = [] - for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds - - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) - num_frames = (num_frames0 + ds - 1) // ds - frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) - u = self.encoder_unmasked_dims[i] - feature_mask[:, :, u:] *= frame_mask - feature_masks.append(feature_mask) - - return feature_masks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - chunk_size: - The chunk size used in evaluation mode. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. 
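The feature-masking scheme implemented in the removed get_feature_masks() above can be summarised with a small standalone sketch. This is an illustration of the idea rather than the deleted code: on roughly 15% of frames it zeroes every channel above some "unmasked" dimension, so those frames effectively see a narrower encoder. The helper name and the concrete dimensions below are made up for the example.

    import torch

    def toy_feature_mask(
        num_frames: int,
        batch_size: int,
        encoder_dim: int,
        unmasked_dim: int,
        dropout_prob: float = 0.15,
    ) -> torch.Tensor:
        """Return a (num_frames, batch_size, encoder_dim) mask that, on roughly
        `dropout_prob` of the frames, zeroes every channel with index >= unmasked_dim,
        so those frames effectively use a smaller encoder dimension."""
        # 1.0 where the frame keeps its full width, 0.0 where it is reduced.
        frame_mask = (torch.rand(num_frames, batch_size, 1) > dropout_prob).to(torch.float32)
        feature_mask = torch.ones(num_frames, batch_size, encoder_dim)
        # Channels below `unmasked_dim` are always kept; the rest follow frame_mask.
        feature_mask[:, :, unmasked_dim:] *= frame_mask
        return feature_mask

    mask = toy_feature_mask(num_frames=100, batch_size=4, encoder_dim=384, unmasked_dim=256)
    print(mask.shape)  # torch.Size([100, 4, 384])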
- """ - x = self.encoder_embed(x) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) - - outputs = [] - feature_masks = self.get_feature_masks(x) - - if self.training: - # Training mode - max_ds = max(self.zipformer_downsampling_factors) - # Generate dynamic chunk-wise attention mask during training - max_len = x.size(0) // max_ds - short_chunk_size = self.short_chunk_size // max_ds - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - # Full attention - chunk_size = x.size(0) - else: - # Chunk-wise attention - chunk_size = chunk_size % short_chunk_size + 1 - chunk_size *= max_ds - else: - chunk_size = self.decode_chunk_size - # Evaluation mode - for ds in self.zipformer_downsampling_factors: - assert chunk_size % ds == 0, (chunk_size, ds) - - attn_mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - ds = self.zipformer_downsampling_factors[i] - k = self.skip_layers[i] - if isinstance(k, int): - layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): - x = skip_module(outputs[k], x) - elif (not self.training) or random.random() > layer_skip_dropout_prob: - x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - attn_mask=attn_mask[::ds, ::ds], - ) - outputs.append(x) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: List[Tensor], - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - seq_len is the input chunk length. - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 3 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states. 
- """ - assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) - - cached_len = states[: self.num_encoders] - cached_avg = states[self.num_encoders : 2 * self.num_encoders] - cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] - cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] - cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] - cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] - cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] - - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - - outputs = [] - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - k = self.skip_layers[i] - if isinstance(k, int): - x = skip_module(outputs[k], x) - x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( - x, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - outputs.append(x) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = ( - new_cached_len - + new_cached_avg - + new_cached_key - + new_cached_val - + new_cached_val2 - + new_cached_conv1 - + new_cached_conv2 - ) - return x, lengths, new_states - - @torch.jit.export - def get_init_state( - self, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
- """ - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - left_context_len = self.decode_chunk_size * self.num_left_chunks - - for i, encoder in enumerate(self.encoders): - num_layers = encoder.num_layers - ds = self.zipformer_downsampling_factors[i] - - len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) - cached_len.append(len_avg) - - avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) - cached_avg.append(avg) - - key = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim, - device=device, - ) - cached_key.append(key) - - val = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val.append(val) - - val2 = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val2.append(val2) - - conv1 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv1.append(conv1) - - conv2 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -class ZipformerEncoderLayer(nn.Module): - """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, - ) -> None: - super(ZipformerEncoderLayer, self).__init__() - - self.d_model = d_model - self.attention_dim = attention_dim - self.cnn_module_kernel = cnn_module_kernel - - # will be written to, see set_batch_count() - self.batch_count = 0 - - self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, - ) - - self.pooling = PoolingModule(d_model) - - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_final = BasicNorm(d_model) - - self.bypass_scale = nn.Parameter(torch.tensor(0.5)) - - # try to ensure the output is close to zero-mean (or at least, zero-median). 
- self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.whiten = Whiten( - num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 - ) - - def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - if random.random() < 0.1: - # ensure we get grads if self.bypass_scale becomes out of range - return self.bypass_scale - # hardcode warmup period for bypass scale - warmup_period = 20000.0 - initial_clamp_min = 0.75 - final_clamp_min = 0.25 - if self.batch_count > warmup_period: - clamp_min = final_clamp_min - else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) - return self.bypass_scale.clamp(min=clamp_min, max=1.0) - - def get_dynamic_dropout_rate(self): - # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this - # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable - # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: - return 0.0 - warmup_period = 2000.0 - initial_dropout_rate = 0.2 - final_dropout_rate = 0.0 - if self.batch_count > warmup_period: - return final_dropout_rate - else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - batch_split: if not None, this layer will only be applied to - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - # dropout rate for submodules that interact with time. 
- dynamic_dropout = self.get_dynamic_dropout_rate() - - # pooling module - if torch.jit.is_scripting(): - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - elif random.random() >= dynamic_dropout: - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - - if torch.jit.is_scripting(): - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn.forward2(src, attn_weights) - - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - else: - use_self_attn = random.random() >= dynamic_dropout - if use_self_attn: - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - if random.random() >= dynamic_dropout: - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - if use_self_attn: - src = src + self.self_attn.forward2(src, attn_weights) - - if random.random() >= dynamic_dropout: - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.get_bypass_scale() - - return self.whiten(src) - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - cached_len: processed number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor of left context for the first attention module. - cached_val: cached value tensor of left context for the first attention module. - cached_val2: cached value tensor of left context for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - pos_emb: (N, left_context_len+2*S-1, E) - cached_len: (N,) - N is the batch size. - cached_avg: (N, C). - N is the batch size, C is the feature dimension. - cached_key: (left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. 
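Several of the schedules in the removed layer above are built around the same warm-up idea: get_bypass_scale() linearly interpolates its clamp minimum from an initial value to a final value over a warmup period measured in batches, then holds it, and get_dynamic_dropout_rate() uses a similar batch-count-driven decay. A minimal sketch of that pattern follows, with the bypass-clamp numbers plugged in as an example; the function name is invented here.

    def linear_warmup(batch_count: float, warmup_period: float,
                      initial: float, final: float) -> float:
        """Linearly move a scheduled value from `initial` to `final` over
        `warmup_period` batches, then hold it at `final`."""
        if batch_count >= warmup_period:
            return final
        t = batch_count / warmup_period
        return initial + t * (final - initial)

    # Bypass-scale clamp minimum: 0.75 at the start of training, 0.25 after 20k batches.
    for step in (0, 5_000, 10_000, 20_000, 40_000):
        print(step, round(linear_warmup(step, 20_000.0, 0.75, 0.25), 3))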
- """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - src_pool, cached_len, cached_avg = self.pooling.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - ) - src = src + src_pool - - ( - src_attn, - attn_weights, - cached_key, - cached_val, - ) = self.self_attn.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - cached_val=cached_val, - ) - src = src + src_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - src_attn, cached_val2 = self.self_attn.streaming_forward2( - src, - attn_weights, - cached_val=cached_val2, - ) - src = src + src_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.bypass_scale - - return ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - ) -> None: - super().__init__() - # will be written to, see set_batch_count() Note: in inference time this - # may be zero but should be treated as large, we can check if - # self.training is true. - self.batch_count = 0 - self.warmup_begin = warmup_begin - self.warmup_end = warmup_end - # module_seed is for when we need a random number that is unique to the module but - # shared across jobs. It's used to randomly select how many layers to drop, - # so that we can keep this consistent across worker tasks (for efficiency). - self.module_seed = torch.randint(0, 1000, ()).item() - - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.d_model = encoder_layer.d_model - self.attention_dim = encoder_layer.attention_dim - self.cnn_module_kernel = encoder_layer.cnn_module_kernel - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin - for i in range(num_layers): - self.layers[i].warmup_begin = cur_begin - cur_begin += delta - self.layers[i].warmup_end = cur_begin - - def get_layers_to_drop(self, rnd_seed: int): - ans = set() - if not self.training: - return ans - - batch_count = self.batch_count - num_layers = len(self.layers) - - def get_layerdrop_prob(layer: int) -> float: - layer_warmup_begin = self.layers[layer].warmup_begin - layer_warmup_end = self.layers[layer].warmup_end - - initial_layerdrop_prob = 0.5 - final_layerdrop_prob = 0.05 - - if batch_count == 0: - # As a special case, if batch_count == 0, return 0 (drop no - # layers). 
This is rather ugly, I'm afraid; it is intended to - # enable our scan_pessimistic_batches_for_oom() code to work correctly - # so if we are going to get OOM it will happen early. - # also search for 'batch_count' with quotes in this file to see - # how we initialize the warmup count to a random number between - # 0 and 10. - return 0.0 - elif batch_count < layer_warmup_begin: - return initial_layerdrop_prob - elif batch_count > layer_warmup_end: - return final_layerdrop_prob - else: - # linearly interpolate - t = (batch_count - layer_warmup_begin) / layer_warmup_end - assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) - - shared_rng = random.Random(batch_count + self.module_seed) - independent_rng = random.Random(rnd_seed) - - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] - tot = sum(layerdrop_probs) - # Instead of drawing the samples independently, we first randomly decide - # how many layers to drop out, using the same random number generator between - # jobs so that all jobs drop out the same number (this is for speed). - # Then we use an approximate approach to drop out the individual layers - # with their specified probs while reaching this exact target. - num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) - - layers = list(range(num_layers)) - independent_rng.shuffle(layers) - - # go through the shuffled layers until we get the required number of samples. - if num_to_drop > 0: - for layer in itertools.cycle(layers): - if independent_rng.random() < layerdrop_probs[layer]: - ans.add(layer) - if len(ans) == num_to_drop: - break - if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) - return ans - - def forward( - self, - src: Tensor, - # Note: The type of feature_mask should be Union[float, Tensor], - # but to make torch.jit.script() work, we use `float` here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). 
- S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: (x, x_no_combine), both of shape (S, N, E) - """ - pos_emb = self.encoder_pos(src) - output = src - - if torch.jit.is_scripting(): - layers_to_drop = [] - else: - rnd_seed = src.numel() + random.randint(0, 1000) - layers_to_drop = self.get_layers_to_drop(rnd_seed) - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): - if i in layers_to_drop: - continue - output = mod( - output, - pos_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - output = output * feature_mask - - return output - - @torch.jit.export - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - cached_len: number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor for first attention module. - cached_val: cached value tensor for first attention module. - cached_val2: cached value tensor for second attention module. - cached_conv1: cached left contexts for the first convolution module. - cached_conv2: cached left contexts for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (num_layers,) - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - - Returns: A tuple of 8 tensors: - - output tensor - - updated cached number of past frames. - - updated cached average of past frames. - - updated cached key tensor of of the first attention module. - - updated cached value tensor of of the first attention module. - - updated cached value tensor of of the second attention module. - - updated cached left contexts of the first convolution module. - - updated cached left contexts of the second convolution module. 
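get_layers_to_drop() above separates two random decisions: a shared RNG (seeded from batch_count plus module_seed) fixes how many layers to drop, so every data-parallel worker drops the same number, while an independent RNG picks which layers are dropped locally. The simplified sketch below keeps only that two-RNG structure; the per-layer warm-up windows are replaced by fixed probabilities and the function name is invented for the example.

    import itertools
    import random
    from typing import List, Set

    def pick_layers_to_drop(
        layerdrop_probs: List[float], shared_seed: int, independent_seed: int
    ) -> Set[int]:
        """A shared RNG decides HOW MANY layers to drop (identical across workers);
        an independent RNG decides WHICH layers are dropped on this worker."""
        shared_rng = random.Random(shared_seed)
        independent_rng = random.Random(independent_seed)

        tot = sum(layerdrop_probs)
        # Round the expected count up or down at random so its mean equals `tot`.
        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))

        layers = list(range(len(layerdrop_probs)))
        independent_rng.shuffle(layers)

        dropped: Set[int] = set()
        if num_to_drop > 0:
            for layer in itertools.cycle(layers):
                if independent_rng.random() < layerdrop_probs[layer]:
                    dropped.add(layer)
                    if len(dropped) == num_to_drop:
                        break
        return dropped

    print(pick_layers_to_drop([0.05] * 12, shared_seed=123, independent_seed=7))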
- """ - assert cached_len.size(0) == self.num_layers, ( - cached_len.size(0), - self.num_layers, - ) - assert cached_avg.size(0) == self.num_layers, ( - cached_avg.size(0), - self.num_layers, - ) - assert cached_key.size(0) == self.num_layers, ( - cached_key.size(0), - self.num_layers, - ) - assert cached_val.size(0) == self.num_layers, ( - cached_val.size(0), - self.num_layers, - ) - assert cached_val2.size(0) == self.num_layers, ( - cached_val2.size(0), - self.num_layers, - ) - assert cached_conv1.size(0) == self.num_layers, ( - cached_conv1.size(0), - self.num_layers, - ) - assert cached_conv2.size(0) == self.num_layers, ( - cached_conv2.size(0), - self.num_layers, - ) - - left_context_len = cached_key.shape[1] - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - for i, mod in enumerate(self.layers): - output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( - output, - pos_emb, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - return ( - output, - torch.stack(new_cached_len, dim=0), - torch.stack(new_cached_avg, dim=0), - torch.stack(new_cached_key, dim=0), - torch.stack(new_cached_val, dim=0), - torch.stack(new_cached_val2, dim=0), - torch.stack(new_cached_conv1, dim=0), - torch.stack(new_cached_conv2, dim=0), - ) - - -class DownsampledZipformerEncoder(nn.Module): - r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int - ): - super(DownsampledZipformerEncoder, self).__init__() - self.downsample_factor = downsample - self.downsample = AttentionDownsample(input_dim, output_dim, downsample) - self.encoder = encoder - self.num_layers = encoder.num_layers - self.d_model = encoder.d_model - self.attention_dim = encoder.attention_dim - self.cnn_module_kernel = encoder.cnn_module_kernel - self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) - - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. feature_mask is expected to be already downsampled by - self.downsample_factor. - attn_mask: attention mask (optional). Should be downsampled already. - src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. - - Shape: - src: (S, N, E). 
- attn_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - src = self.encoder( - src, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - cached_avg: cached average value of past frames. - cached_len: length of past frames. - cached_key: cached key tensor for the first attention module. - cached_val: cached value tensor for the first attention module. - cached_val2: cached value tensor for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = self.encoder.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - cached_key=cached_key, - cached_val=cached_val, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return ( - self.out_combiner(src_orig, src), - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class AttentionDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. 
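The removed DownsampledZipformerEncoder above follows a simple wrap pattern: downsample the input, run the inner encoder at the reduced frame rate, upsample back, drop the padding frames, and mix the result with the original input so the output keeps the input shape. The toy sketch below shows only that pattern; it substitutes naive average-pool downsampling and frame-repeat upsampling for the learned AttentionDownsample, SimpleUpsample and SimpleCombiner modules (whose implementations follow), and the class name is made up.

    import torch
    import torch.nn as nn

    class ToyDownsampledEncoder(nn.Module):
        """Run an inner encoder at a reduced frame rate: downsample, encode,
        upsample back, trim the padding frames, then mix with the original input."""

        def __init__(self, encoder: nn.Module, downsample: int):
            super().__init__()
            self.encoder = encoder
            self.ds = downsample

        def forward(self, src: torch.Tensor) -> torch.Tensor:  # src: (S, N, E)
            seq_len = src.shape[0]
            ds = self.ds
            pad = (-seq_len) % ds
            if pad > 0:
                # Right-pad by repeating the last frame so S becomes a multiple of ds.
                src_pad = torch.cat([src, src[-1:].expand(pad, *src.shape[1:])], dim=0)
            else:
                src_pad = src
            # Naive downsampling: average each group of `ds` frames.
            x = src_pad.reshape(-1, ds, *src.shape[1:]).mean(dim=1)
            x = self.encoder(x)
            # Naive upsampling: repeat each frame `ds` times, then drop the padding.
            x = x.repeat_interleave(ds, dim=0)[:seq_len]
            # Combine with the original (the removed code uses a learned SimpleCombiner).
            return 0.5 * (src + x)

    enc = ToyDownsampledEncoder(nn.Identity(), downsample=4)
    print(enc(torch.randn(10, 2, 8)).shape)  # torch.Size([10, 2, 8])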
- """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): - super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) - - # fill in the extra dimensions with a projection of the input - if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) - else: - self.extra_proj = None - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, 1, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, out_channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) - - if self.extra_proj is not None: - ans2 = self.extra_proj(src) - ans = torch.cat((ans, ans2), dim=2) - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.bias.shape[0] - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class SimpleCombinerIdentity(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - return src1 - - -class SimpleCombiner(torch.nn.Module): - """ - A very simple way of combining 2 vectors of 2 different dims, via a - learned weighted combination in the shared part of the dim. - Args: - dim1: the dimension of the first input, e.g. 256 - dim2: the dimension of the second input, e.g. 384. - The output will have the same dimension as dim2. 
- """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): - super(SimpleCombiner, self).__init__() - assert dim2 >= dim1, (dim2, dim1) - self.weight1 = nn.Parameter(torch.zeros(())) - self.min_weight = min_weight - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - """ - src1: (*, dim1) - src2: (*, dim2) - - Returns: a tensor of shape (*, dim2) - """ - assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) - - weight1 = self.weight1 - if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) - - src1 = src1 * weight1 - src2 = src2 * (1.0 - weight1) - - src1_dim = src1.shape[-1] - src2_dim = src2.shape[-1] - if src1_dim != src2_dim: - if src1_dim < src2_dim: - src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) - else: - src1 = src1[:src2_dim] - - return src1 + src2 - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - ) -> None: - """Construct a PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - x_size_left = x.size(0) + left_context_len - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_left * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tensor: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). - - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_left - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(0), - ] - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: total dimension of the model. - attention_dim: dimension in the attention module, may be less or more than embed_dim - but must be a multiple of num_heads. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. 
- - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.attention_dim = attention_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim # query, key - + attention_dim // 2 # value - + pos_dim * num_heads # positional encoding query - ) - - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 - ) - - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option. - # they only copy their inputs. - self.copy_pos_query = Identity() - self.copy_query = Identity() - - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
- If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. - - - Returns: (attn_output, attn_weights, cached_key, cached_val) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of - left context - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of - """ - ( - x, - weights, - cached_key, - cached_val, - ) = self.streaming_multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.out_proj.weight, - self.out_proj.bias, - cached_key=cached_key, - cached_val=cached_val, - ) - return x, weights, cached_key, cached_val - - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. 
- pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. - q = self.copy_query(q) # for diagnostics only, does nothing. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. 
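The single in_proj above packs the query, key, value and positional query into one projection of width 2 * attention_dim + attention_dim // 2 + pos_dim * num_heads, and the code just shown slices that projection back apart. The snippet below replays only that slicing on a random tensor; the dimensions are illustrative and not taken from any actual configuration.

    import torch

    # Illustrative dimensions only (matching the ratios used above, not a real config).
    attention_dim, num_heads, pos_dim = 256, 8, 4
    in_proj_dim = 2 * attention_dim + attention_dim // 2 + pos_dim * num_heads

    seq_len, batch = 16, 2
    x_proj = torch.randn(seq_len, batch, in_proj_dim)  # stands in for self.in_proj(x)

    # Slice the combined projection the same way the removed forward() does:
    q = x_proj[..., 0:attention_dim]                                    # query
    k = x_proj[..., attention_dim : 2 * attention_dim]                  # key
    value_dim = attention_dim // 2
    v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim]  # value (half width)
    p = x_proj[..., 2 * attention_dim + value_dim :]                    # positional query

    print(q.shape[-1], k.shape[-1], v.shape[-1], p.shape[-1])  # 256 256 128 32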
- - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - if not torch.jit.is_scripting(): - if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. 
- # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights = attn_output_weights.masked_fill( - attn_mask, float("-inf") - ) - else: - attn_output_weights = attn_output_weights + attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - # If we are using chunk-wise attention mask and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`. At this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax). So we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights - - def streaming_multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. 
- out_proj_weight, out_proj_bias: the output projection weight and bias. - cached_key: cached attention key tensor of left context. - cached_val: cached attention value tensor of left context. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - left_context_len = cached_key.shape[0] - assert left_context_len > 0, left_context_len - assert cached_key.shape[0] == cached_val.shape[0], ( - cached_key.shape, - cached_val.shape, - ) - # Pad cached left contexts - k = torch.cat([cached_key, k], dim=0) - v = torch.cat([cached_val, v], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - cached_val = v[-left_context_len:, ...] - - # The length of key and value - kv_len = k.shape[0] - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(kv_len, bsz, num_heads, head_dim) - v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 + left_context_len - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. 
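The comment above (and the identical one in the non-streaming path further up) describes re-indexing the last axis of pos_weights from relative offsets to key positions. A standalone sketch of that re-indexing with toy sizes, checking the strided view against the explicit gather used in the torch.jit.is_tracing() branch; all names below are illustrative and nothing here is part of the original file:

import torch

bsz, num_heads, seq_len, left_context_len = 2, 3, 5, 2
kv_len = seq_len + left_context_len
seq_len2 = 2 * seq_len - 1 + left_context_len  # number of relative-offset columns

# pos_weights[b, h, t, r] ends up at key column s = t + r - (seq_len - 1),
# i.e. out[b, h, t, s] == pos_weights[b, h, t, (s - t) + (seq_len - 1)].
pos_weights = torch.randn(bsz, num_heads, seq_len, seq_len2)

# The strided view used in the non-tracing branch.
strided = pos_weights.as_strided(
    (bsz, num_heads, seq_len, kv_len),
    (
        pos_weights.stride(0),
        pos_weights.stride(1),
        pos_weights.stride(2) - pos_weights.stride(3),
        pos_weights.stride(3),
    ),
    storage_offset=pos_weights.stride(3) * (seq_len - 1),
)

# The equivalent explicit gather used in the tracing branch.
rows = torch.arange(seq_len - 1, -1, -1)   # (seq_len,)
cols = torch.arange(kv_len)                # (kv_len,)
indexes = (rows.unsqueeze(-1) + cols).repeat(bsz * num_heads, 1)
gathered = torch.gather(pos_weights.reshape(-1, seq_len2), dim=1, index=indexes)
gathered = gathered.reshape(bsz, num_heads, seq_len, kv_len)

assert torch.equal(strided, gathered)

Setting left_context_len = 0 gives the shapes used by the non-streaming attention code above.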
- if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(kv_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights, cached_key, cached_val - - def forward2( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not torch.jit.is_scripting(): - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output) - - def streaming_forward2( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - cached_val: cached attention value tensor of left context. - Returns: - - output of the same shape as x, i.e. 
(seq_len, batch_size, embed_dim) - - updated cached attention value tensor of left context. - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - - left_context_len = cached_val.shape[0] - assert left_context_len > 0, left_context_len - v = torch.cat([cached_val, v], dim=0) - cached_val = v[-left_context_len:] - - seq_len2 = left_context_len + seq_len - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output), cached_val - - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) - # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" - ) - - -class PoolingModule(nn.Module): - """ - Averages the input over the time dimension and project with a square matrix. - """ - - def __init__(self, d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: a Tensor of shape (T, N, C) - src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked - positions. - - Returns: - - output, a Tensor of shape (T, N, C). 
- """ - if src_key_padding_mask is not None: - # False in padding positions - padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) - # Cumulated numbers of frames from start - cum_mask = padding_mask.cumsum(dim=1) # (N, T) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = padding_mask / cum_mask - pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - else: - num_frames = x.shape[0] - cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask - - x = self.proj(x) - return x - - def streaming_forward( - self, - x: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - x: a Tensor of shape (T, N, C) - cached_len: a Tensor of int, of shape (N,), containing the number of - past frames in batch. - cached_avg: a Tensor of shape (N, C), the average over all past frames - in batch. - - Returns: - A tuple of 2 tensors: - - output, a Tensor of shape (T, N, C). - - updated cached_avg, a Tensor of shape (N, C). - """ - x = x.cumsum(dim=0) # (T, N, C) - x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) - # Cumulated numbers of frames from start - cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) - cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - - cached_len = cached_len + x.size(0) - cached_avg = x[-1] - - x = self.proj(x) - return x, cached_len, cached_avg - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) - self.activation = DoubleSwish() - self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.balancer(x) - x = self.activation(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0, kernel_size - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. 
This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, - ) - - # Will pad cached left context - self.lorder = kernel_size - 1 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=0, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, - max_abs=20.0, - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - # 1D Depthwise Conv - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch: - (batch, #time), contains bool in masked positions. - cache: Cached left context for depthwise_conv, with shape of - (batch, channels, #kernel_size-1). Only used in real streaming decoding. - - Returns: - A tuple of 2 tensors: - - Output tensor (#time, batch, channels). - - New cached left context, with shape of (batch, channels, #kernel_size-1). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
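The forward() above makes depthwise_conv causal by left-padding self.lorder = kernel_size - 1 zeros; streaming_forward() here instead prepends a cached left context of the same length. A standalone sketch (toy sizes, plain depthwise nn.Conv1d only; the GLU, balancers and pointwise convs are omitted) of why chunk-wise processing with such a cache reproduces the full-utterance result:

import torch
import torch.nn as nn

torch.manual_seed(0)
channels, kernel_size, batch, time, chunk = 4, 5, 2, 12, 4
lorder = kernel_size - 1
conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)

x = torch.randn(batch, channels, time)

# Full-utterance path: left-pad lorder zeros, as in forward().
full = conv(nn.functional.pad(x, (lorder, 0)))

# Streaming path: start from a zero cache and feed chunk by chunk,
# updating the cache as in streaming_forward().
cache = torch.zeros(batch, channels, lorder)
outs = []
for start in range(0, time, chunk):
    xc = torch.cat([cache, x[:, :, start : start + chunk]], dim=2)
    cache = xc[:, :, -lorder:]
    outs.append(conv(xc))

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)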
- - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - assert cache.shape == (x.size(0), x.size(1), self.lorder), ( - cache.shape, - (x.size(0), x.size(1), self.lorder), - ) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[:, :, -self.lorder :] - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: float = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-7)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer2_channels: - Number of channels in layer2 - layer3_channels: - Number of channels in layer3 - """ - assert in_channels >= 7, in_channels - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ActivationBalancer(layer1_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - ActivationBalancer(layer2_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - ActivationBalancer(layer3_channels, channel_dim=1), - DoubleSwish(), - ) - out_height = (((in_channels - 1) // 2) - 1) // 2 - self.out = ScaledLinear(out_height * layer3_channels, out_channels) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, (T-7)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) - # Now x is of shape (N, (T-7)//2, odim) - x = self.dropout(x) - return x - - -def _test_zipformer_main(): - feature_dim = 50 - batch_size = 5 - seq_len = 47 - feature_dim = 50 - # Just make sure the forward pass runs. - - c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - decode_chunk_size=4, - ) - # Just make sure the forward pass runs. 
- f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -def _test_conv2d_subsampling(): - num_features = 80 - encoder_dims = 384 - dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) - for i in range(20, 40): - x = torch.rand(2, i, num_features) - y = encoder_embed(x) - assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - - -def _test_pooling_module(): - N, S, C = 2, 12, 32 - chunk_len = 4 - m = PoolingModule(d_model=C) - - # test chunk-wise forward with padding_mask - x = torch.randn(S, N, C) - y = m(x) - cached_len = torch.zeros(N, dtype=torch.int32) - cached_avg = torch.zeros(N, C) - for i in range(S // chunk_len): - start = i * chunk_len - end = start + chunk_len - x_chunk = x[start:end] - y_chunk, cached_len, cached_avg = m.streaming_forward( - x_chunk, - cached_len=cached_len, - cached_avg=cached_avg, - ) - assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) - - -def _test_state_stack_unstack(): - m = Zipformer( - num_features=80, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - zipformer_downsampling_factors=(4, 8), - num_left_chunks=2, - decode_chunk_size=8, - ) - s1 = m.get_init_state() - s2 = m.get_init_state() - states = stack_states([s1, s2]) - new_s1, new_s2 = unstack_states(states) - for i in range(m.num_encoders * 7): - for x, y in zip(s1[i], new_s1[i]): - assert torch.equal(x, y) - for x, y in zip(s2[i], new_s2[i]): - assert torch.equal(x, y) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main() - _test_conv2d_subsampling() - _test_pooling_module() - _test_state_stack_unstack() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 100644 index be9cd1608..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1,3144 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
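_test_pooling_module() above checks that chunk-wise streaming_forward() matches the full forward() of PoolingModule. The same recurrence written out on raw tensors (no projection, float counts instead of the int32 cached_len) makes the cached_len / cached_avg bookkeeping explicit; a standalone sketch with toy sizes:

import torch

torch.manual_seed(0)
T, N, C, chunk = 12, 2, 8, 4
x = torch.randn(T, N, C)

# Full-utterance pooling: at frame t, the mean of frames 0..t.
full = x.cumsum(dim=0) / torch.arange(1, T + 1).view(T, 1, 1)

cached_len = torch.zeros(N)     # number of past frames per utterance
cached_avg = torch.zeros(N, C)  # running mean over those frames
outs = []
for start in range(0, T, chunk):
    xc = x[start : start + chunk]  # (chunk, N, C)
    # Reconstruct the running sum, extend it over the chunk, renormalize.
    s = xc.cumsum(dim=0) + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0)
    counts = torch.arange(1, xc.size(0) + 1).unsqueeze(1) + cached_len.unsqueeze(0)
    out = s / counts.unsqueeze(2)
    cached_len = cached_len + xc.size(0)
    cached_avg = out[-1]
    outs.append(out)

assert torch.allclose(torch.cat(outs, dim=0), full, atol=1e-5)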
- -import copy -import itertools -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( # not as in other dirs.. just scales down initial parameter values. - ActivationBalancer, - BasicNorm, - DoubleSwish, - Identity, - MaxEig, - ScaledConv1d, - ScaledLinear, - Whiten, - _diag, - penalize_abs_values_gt, - random_clamp, - softmax, -) -from torch import Tensor, nn -from zipformer import PoolingModule - -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Note: - It is the inverse of :func:`unstack_states`. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - ``states[i][0:num_encoders]`` is the cached numbers of past frames. - ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A new state corresponding to a batch of utterances. - See the input argument of :func:`unstack_states` for the meaning - of the returned tensor. 
- """ - batch_size = len(state_list) - assert len(state_list[0]) % 7 == 0, len(state_list[0]) - num_encoders = len(state_list[0]) // 7 - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - # For cached_len - len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] - for i in range(num_encoders): - # len_avg: (num_layers, batch_size) - len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) - cached_len.append(len_avg) - - # For cached_avg - avg_list = [ - state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # avg: (num_layers, batch_size, D) - avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) - cached_avg.append(avg) - - # For cached_key - key_list = [ - state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # key: (num_layers, left_context_size, batch_size, D) - key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) - cached_key.append(key) - - # For cached_val - val_list = [ - state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val: (num_layers, left_context_size, batch_size, D) - val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) - cached_val.append(val) - - # For cached_val2 - val2_list = [ - state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val2: (num_layers, left_context_size, batch_size, D) - val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) - cached_val2.append(val2) - - # For cached_conv1 - conv1_list = [ - state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv1: (num_layers, batch_size, D, kernel-1) - conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) - cached_conv1.append(conv1) - - # For cached_conv2 - conv2_list = [ - state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv2: (num_layers, batch_size, D, kernel-1) - conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A list of states. 
- ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - """ - assert len(states) % 7 == 0, len(states) - num_encoders = len(states) // 7 - ( - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) - - batch_size = cached_len[0].shape[1] - - len_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_len[i]: (num_layers, batch_size) - len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - len_list[n].append(len_avg[n]) - - avg_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_avg[i]: (num_layers, batch_size, D) - avg = cached_avg[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - avg_list[n].append(avg[n]) - - key_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_key[i]: (num_layers, left_context, batch_size, D) - key = cached_key[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - key_list[n].append(key[n]) - - val_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val[i]: (num_layers, left_context, batch_size, D) - val = cached_val[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val_list[n].append(val[n]) - - val2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val2[i]: (num_layers, left_context, batch_size, D) - val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val2_list[n].append(val2[n]) - - conv1_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) - conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv1_list[n].append(conv1[n]) - - conv2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) - conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv2_list[n].append(conv2[n]) - - state_list = [ - ( - len_list[i] - + avg_list[i] - + key_list[i] - + val_list[i] - + val2_list[i] - + conv1_list[i] - + conv2_list[i] - ) - for i in range(batch_size) - ] - return state_list - - -class Zipformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimension in 2 encoder stacks - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - cnn_module_kernels (int): Kernel size of convolution module - warmup_batches (float): number of batches to warm up over - is_pnnx (bool): True if we are going to convert this model via pnnx. 
- """ - - def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - num_left_chunks: int = 4, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 50, - decode_chunk_size: int = 16, - warmup_batches: float = 4000.0, - is_pnnx: bool = False, - ) -> None: - super(Zipformer, self).__init__() - self.is_pnnx = is_pnnx - - self.num_features = num_features - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor - - self.num_left_chunks = num_left_chunks - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - - # Used in decoding - self.decode_chunk_size = decode_chunk_size - - self.left_context_len = self.decode_chunk_size * self.num_left_chunks - - # will be written to, see set_batch_count() - self.batch_count = 0 - self.warmup_end = warmup_batches - - for u, d in zip(encoder_unmasked_dims, encoder_dims): - assert u <= d, (u, d) - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7)//2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7)//2 - # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout, is_pnnx=is_pnnx - ) - - # each one will be ZipformerEncoder or DownsampledZipformerEncoder - encoders = [] - - self.num_encoders = len(encoder_dims) - for i in range(self.num_encoders): - ds = zipformer_downsampling_factors[i] - encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], - feedforward_dim[i], - dropout, - cnn_module_kernels[i], - pos_dim, - is_pnnx=self.is_pnnx, - left_context_len=self.left_context_len // ds, - x_size=self.decode_chunk_size // ds, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. 
- encoder = ZipformerEncoder( - encoder_layer, - num_encoder_layers[i], - dropout, - warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), - is_pnnx=is_pnnx, - left_context_len=self.left_context_len // ds, - x_size=self.decode_chunk_size // ds, - ) - - if zipformer_downsampling_factors[i] != 1: - in_x_size = self.decode_chunk_size - encoder = DownsampledZipformerEncoder( - encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], - is_pnnx=is_pnnx, - left_context_len=self.left_context_len // ds, - in_x_size=in_x_size, - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) - - # initializes self.skip_layers and self.skip_modules - self._init_skip_modules() - - self.downsample_output = AttentionDownsample( - encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor, - is_pnnx=is_pnnx, - in_x_size=self.decode_chunk_size, - ) - - def _get_layer_skip_dropout_prob(self): - if not self.training: - return 0.0 - batch_count = self.batch_count - min_dropout_prob = 0.025 - - if batch_count > self.warmup_end: - return min_dropout_prob - else: - return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) - - def _init_skip_modules(self): - """ - If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, - we combine the outputs of layers 1 and 4. - """ - skip_layers = [] - skip_modules = [] - z = self.zipformer_downsampling_factors - for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: - skip_layers.append(None) - skip_modules.append(SimpleCombinerIdentity()) - else: - # TEMP - for j in range(i - 2, -1, -1): - if z[j] <= z[i] or j == 0: - # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." - ) - skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) - break - self.skip_layers = skip_layers - self.skip_modules = nn.ModuleList(skip_modules) - - def get_feature_masks(self, x: torch.Tensor) -> List[float]: - # Note: The actual return type is Union[List[float], List[Tensor]], - # but to make torch.jit.script() work, we use List[float] - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_downsampling_factors times. 
- - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (num_frames, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) - - max_downsampling_factor = max(self.zipformer_downsampling_factors) - - num_frames_max = num_frames0 + max_downsampling_factor - 1 - - feature_mask_dropout_prob = 0.15 - - # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) - - feature_masks = [] - for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds - - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) - num_frames = (num_frames0 + ds - 1) // ds - frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) - u = self.encoder_unmasked_dims[i] - feature_mask[:, :, u:] *= frame_mask - feature_masks.append(feature_mask) - - return feature_masks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - chunk_size: - The chunk size used in evaluation mode. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. 
- """ - x = self.encoder_embed(x) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) - - outputs = [] - feature_masks = self.get_feature_masks(x) - - if self.training: - # Training mode - max_ds = max(self.zipformer_downsampling_factors) - # Generate dynamic chunk-wise attention mask during training - max_len = x.size(0) // max_ds - short_chunk_size = self.short_chunk_size // max_ds - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - # Full attention - chunk_size = x.size(0) - else: - # Chunk-wise attention - chunk_size = chunk_size % short_chunk_size + 1 - chunk_size *= max_ds - else: - chunk_size = self.decode_chunk_size - # Evaluation mode - for ds in self.zipformer_downsampling_factors: - assert chunk_size % ds == 0, (chunk_size, ds) - - attn_mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - ds = self.zipformer_downsampling_factors[i] - k = self.skip_layers[i] - if isinstance(k, int): - layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): - x = skip_module(outputs[k], x) - elif (not self.training) or random.random() > layer_skip_dropout_prob: - x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - attn_mask=attn_mask[::ds, ::ds], - ) - outputs.append(x) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths - - def streaming_forward( - self, - x: torch.Tensor, - states: List[Tensor], - ) -> Tuple[Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - seq_len is the input chunk length. - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 3 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states. 
- """ - assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) - - cached_len = states[: self.num_encoders] - cached_avg = states[self.num_encoders : 2 * self.num_encoders] - cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] - cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] - cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] - cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] - cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] - - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - outputs = [] - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - k = self.skip_layers[i] - if isinstance(k, int): - x = skip_module(outputs[k], x) - x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( - x, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - - outputs.append(x) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = ( - new_cached_len - + new_cached_avg - + new_cached_key - + new_cached_val - + new_cached_val2 - + new_cached_conv1 - + new_cached_conv2 - ) - return x, new_states - - @torch.jit.export - def get_init_state( - self, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
- """ - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - for i, encoder in enumerate(self.encoders): - num_layers = encoder.num_layers - ds = self.zipformer_downsampling_factors[i] - - len_avg = torch.zeros(num_layers, 1, device=device) - cached_len.append(len_avg) - - avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) - cached_avg.append(avg) - - key = torch.zeros( - num_layers, - self.left_context_len // ds, - 1, - encoder.attention_dim, - device=device, - ) - cached_key.append(key) - - val = torch.zeros( - num_layers, - self.left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val.append(val) - - val2 = torch.zeros( - num_layers, - self.left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val2.append(val2) - - conv1 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv1.append(conv1) - - conv2 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -class ZipformerEncoderLayer(nn.Module): - """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, - is_pnnx: bool = False, - left_context_len: int = 0, - x_size: int = 0, - ) -> None: - super(ZipformerEncoderLayer, self).__init__() - - self.d_model = d_model - self.attention_dim = attention_dim - self.cnn_module_kernel = cnn_module_kernel - - # will be written to, see set_batch_count() - self.batch_count = 0 - - self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, - is_pnnx=is_pnnx, - left_context_len=left_context_len, - x_size=x_size, - ) - - self.pooling = PoolingModule(d_model) - - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.conv_module1 = ConvolutionModule( - d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size - ) - - self.conv_module2 = ConvolutionModule( - d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size - ) - - self.norm_final = BasicNorm(d_model) - - self.bypass_scale = nn.Parameter(torch.tensor(0.5)) - - # try to ensure the output is close to zero-mean (or at least, zero-median). 
- self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.whiten = Whiten( - num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 - ) - - def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - if random.random() < 0.1: - # ensure we get grads if self.bypass_scale becomes out of range - return self.bypass_scale - # hardcode warmup period for bypass scale - warmup_period = 20000.0 - initial_clamp_min = 0.75 - final_clamp_min = 0.25 - if self.batch_count > warmup_period: - clamp_min = final_clamp_min - else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) - return self.bypass_scale.clamp(min=clamp_min, max=1.0) - - def get_dynamic_dropout_rate(self): - # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this - # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable - # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: - return 0.0 - warmup_period = 2000.0 - initial_dropout_rate = 0.2 - final_dropout_rate = 0.0 - if self.batch_count > warmup_period: - return final_dropout_rate - else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - batch_split: if not None, this layer will only be applied to - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - # dropout rate for submodules that interact with time. 
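get_bypass_scale() and get_dynamic_dropout_rate() above are simple warm-up schedules driven by batch_count. Note that, because final_dropout_rate is 0.0, the expression initial_dropout_rate - (initial_dropout_rate * final_dropout_rate) * (batch_count / warmup_period) evaluates to a constant 0.2 for batch_count <= warmup_period and then drops to 0.0, whereas the comment above it describes a rate that decreases to 0; a ramp matching that comment would look like the hypothetical helper below (not part of the original file):

def dynamic_dropout_rate(batch_count: float,
                         warmup_period: float = 2000.0,
                         initial: float = 0.2,
                         final: float = 0.0) -> float:
    # Linear decay from `initial` to `final` over `warmup_period` batches.
    if batch_count > warmup_period:
        return final
    return initial - (initial - final) * (batch_count / warmup_period)

assert dynamic_dropout_rate(0.0) == 0.2
assert abs(dynamic_dropout_rate(1000.0) - 0.1) < 1e-9
assert dynamic_dropout_rate(3000.0) == 0.0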
- dynamic_dropout = self.get_dynamic_dropout_rate() - - # pooling module - if torch.jit.is_scripting(): - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - elif random.random() >= dynamic_dropout: - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - - if torch.jit.is_scripting(): - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn.forward2(src, attn_weights) - - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - else: - use_self_attn = random.random() >= dynamic_dropout - if use_self_attn: - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - if random.random() >= dynamic_dropout: - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - if use_self_attn: - src = src + self.self_attn.forward2(src, attn_weights) - - if random.random() >= dynamic_dropout: - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.get_bypass_scale() - - return self.whiten(src) - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - cached_len: processed number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor of left context for the first attention module. - cached_val: cached value tensor of left context for the first attention module. - cached_val2: cached value tensor of left context for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - pos_emb: (N, left_context_len+2*S-1, E) - cached_len: (N,) - N is the batch size. - cached_avg: (N, C). - N is the batch size, C is the feature dimension. - cached_key: (left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. 
- """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - src_pool, cached_len, cached_avg = self.pooling.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - ) - src = src + src_pool - - ( - src_attn, - attn_weights, - cached_key, - cached_val, - ) = self.self_attn.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - cached_val=cached_val, - ) - - src = src + src_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - ) - - src = src + src_conv - - src = src + self.feed_forward2(src) - - src_attn, cached_val2 = self.self_attn.streaming_forward2( - src, - attn_weights, - cached_val=cached_val2, - ) - src = src + src_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.bypass_scale - - return ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class ZipformerStateSelect(nn.Module): - """ncnn does not support selecting along batch index. - This class provides a workaround for it. We - need to change pnnx accordingly. - """ - - def __init__(self, i: int): - super().__init__() - self.i = i - - def forward(self, x: torch.Tensor): - return x[self.i] - - -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - is_pnnx: bool = False, - x_size: int = 0, - left_context_len: int = 0, - ) -> None: - super().__init__() - # will be written to, see set_batch_count() Note: in inference time this - # may be zero but should be treated as large, we can check if - # self.training is true. - self.batch_count = 0 - self.warmup_begin = warmup_begin - self.warmup_end = warmup_end - # module_seed is for when we need a random number that is unique to the module but - # shared across jobs. It's used to randomly select how many layers to drop, - # so that we can keep this consistent across worker tasks (for efficiency). 
- self.module_seed = torch.randint(0, 1000, ()).item() - self.left_context_len = left_context_len - - self.encoder_pos = RelPositionalEncoding( - encoder_layer.d_model, - dropout, - is_pnnx=is_pnnx, - x_size=x_size, - left_context_len=left_context_len, - ) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - state_select_list = [] - for i in range(num_layers): - state_select_list.append(ZipformerStateSelect(i)) - self.state_select_list = nn.ModuleList(state_select_list) - - self.d_model = encoder_layer.d_model - self.attention_dim = encoder_layer.attention_dim - self.cnn_module_kernel = encoder_layer.cnn_module_kernel - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin - for i in range(num_layers): - self.layers[i].warmup_begin = cur_begin - cur_begin += delta - self.layers[i].warmup_end = cur_begin - - def get_layers_to_drop(self, rnd_seed: int): - ans = set() - if not self.training: - return ans - - batch_count = self.batch_count - num_layers = len(self.layers) - - def get_layerdrop_prob(layer: int) -> float: - layer_warmup_begin = self.layers[layer].warmup_begin - layer_warmup_end = self.layers[layer].warmup_end - - initial_layerdrop_prob = 0.5 - final_layerdrop_prob = 0.05 - - if batch_count == 0: - # As a special case, if batch_count == 0, return 0 (drop no - # layers). This is rather ugly, I'm afraid; it is intended to - # enable our scan_pessimistic_batches_for_oom() code to work correctly - # so if we are going to get OOM it will happen early. - # also search for 'batch_count' with quotes in this file to see - # how we initialize the warmup count to a random number between - # 0 and 10. - return 0.0 - elif batch_count < layer_warmup_begin: - return initial_layerdrop_prob - elif batch_count > layer_warmup_end: - return final_layerdrop_prob - else: - # linearly interpolate - t = (batch_count - layer_warmup_begin) / layer_warmup_end - assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) - - shared_rng = random.Random(batch_count + self.module_seed) - independent_rng = random.Random(rnd_seed) - - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] - tot = sum(layerdrop_probs) - # Instead of drawing the samples independently, we first randomly decide - # how many layers to drop out, using the same random number generator between - # jobs so that all jobs drop out the same number (this is for speed). - # Then we use an approximate approach to drop out the individual layers - # with their specified probs while reaching this exact target. - num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) - - layers = list(range(num_layers)) - independent_rng.shuffle(layers) - - # go through the shuffled layers until we get the required number of samples. 
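As an aside, the two-RNG sampling scheme described in the comments above can be exercised on its own. The following standalone sketch mirrors the logic of get_layers_to_drop; the function and seed names are illustrative and not part of this file.

import itertools
import random

def sample_layers_to_drop(layerdrop_probs, shared_seed, local_seed):
    # RNG seeded identically on every worker: all workers agree on *how many*
    # layers to drop, which keeps per-batch compute balanced across jobs.
    shared_rng = random.Random(shared_seed)
    # RNG seeded differently per worker/batch: decides *which* layers to drop.
    local_rng = random.Random(local_seed)

    # Stochastically round the expected number of dropped layers.
    tot = sum(layerdrop_probs)
    num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))

    dropped = set()
    if num_to_drop > 0:
        layers = list(range(len(layerdrop_probs)))
        local_rng.shuffle(layers)
        # Cycle through the shuffled layers, accepting each with its own
        # probability, until the agreed-upon count is reached.
        for layer in itertools.cycle(layers):
            if local_rng.random() < layerdrop_probs[layer]:
                dropped.add(layer)
                if len(dropped) == num_to_drop:
                    break
    return dropped

print(sample_layers_to_drop([0.5, 0.3, 0.05, 0.05], shared_seed=7, local_seed=42))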
- if num_to_drop > 0: - for layer in itertools.cycle(layers): - if independent_rng.random() < layerdrop_probs[layer]: - ans.add(layer) - if len(ans) == num_to_drop: - break - if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) - return ans - - def forward( - self, - src: Tensor, - # Note: The type of feature_mask should be Union[float, Tensor], - # but to make torch.jit.script() work, we use `float` here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: (x, x_no_combine), both of shape (S, N, E) - """ - pos_emb = self.encoder_pos(src) - output = src - - if torch.jit.is_scripting(): - layers_to_drop = [] - else: - rnd_seed = src.numel() + random.randint(0, 1000) - layers_to_drop = self.get_layers_to_drop(rnd_seed) - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): - if i in layers_to_drop: - continue - output = mod( - output, - pos_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - output = output * feature_mask - - return output - - @torch.jit.export - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - cached_len: number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor for first attention module. - cached_val: cached value tensor for first attention module. - cached_val2: cached value tensor for second attention module. - cached_conv1: cached left contexts for the first convolution module. - cached_conv2: cached left contexts for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (num_layers,) - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - - Returns: A tuple of 8 tensors: - - output tensor - - updated cached number of past frames. - - updated cached average of past frames. - - updated cached key tensor of of the first attention module. 
- - updated cached value tensor of of the first attention module. - - updated cached value tensor of of the second attention module. - - updated cached left contexts of the first convolution module. - - updated cached left contexts of the second convolution module. - """ - assert cached_len.size(0) == self.num_layers, ( - cached_len.size(0), - self.num_layers, - ) - assert cached_avg.size(0) == self.num_layers, ( - cached_avg.size(0), - self.num_layers, - ) - assert cached_key.size(0) == self.num_layers, ( - cached_key.size(0), - self.num_layers, - ) - assert cached_val.size(0) == self.num_layers, ( - cached_val.size(0), - self.num_layers, - ) - assert cached_val2.size(0) == self.num_layers, ( - cached_val2.size(0), - self.num_layers, - ) - assert cached_conv1.size(0) == self.num_layers, ( - cached_conv1.size(0), - self.num_layers, - ) - assert cached_conv2.size(0) == self.num_layers, ( - cached_conv2.size(0), - self.num_layers, - ) - - assert self.left_context_len == cached_key.shape[1], ( - self.left_context_len, - cached_key.shape[1], - ) - - left_context_len = self.left_context_len - pos_emb = self.encoder_pos(src, left_context_len) - - output = src - - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - for i, (mod, state_select) in enumerate( - zip(self.layers, self.state_select_list) - ): - output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( - output, - pos_emb, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=state_select(cached_conv1), - cached_conv2=state_select(cached_conv2), - ) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - return ( - output, - torch.stack(new_cached_len, dim=0), - torch.stack(new_cached_avg, dim=0), - torch.stack(new_cached_key, dim=0), - torch.stack(new_cached_val, dim=0), - torch.stack(new_cached_val2, dim=0), - torch.stack(new_cached_conv1, dim=0), - torch.stack(new_cached_conv2, dim=0), - ) - - -class DownsampledZipformerEncoder(nn.Module): - r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. 
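The surrounding class is easier to follow as a toy end-to-end flow. The sketch below substitutes mean-pooling, an identity encoder and frame repetition for AttentionDownsample, ZipformerEncoder and SimpleUpsample, and plain addition for SimpleCombiner; only the downsample / encode / upsample / trim / combine ordering is taken from the code.

import torch

def toy_downsampled_encoder(src, factor=2):
    # src: (seq_len, batch, channels)
    src_orig = src
    seq_len, batch, channels = src.shape

    # Right-pad so the length is a multiple of `factor`, repeating the last frame.
    pad = (-seq_len) % factor
    if pad > 0:
        src = torch.cat([src, src[-1:].expand(pad, batch, channels)], dim=0)

    x = src.reshape(-1, factor, batch, channels).mean(dim=1)  # downsample
    x = x                                                     # stand-in for the low-rate encoder
    x = x.repeat_interleave(factor, dim=0)                    # upsample
    x = x[:seq_len]                                           # drop the padding frames again
    return src_orig + x                                       # combine with the full-rate input

out = toy_downsampled_encoder(torch.randn(7, 2, 4))
assert out.shape == (7, 2, 4)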
- """ - - def __init__( - self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int, - is_pnnx: bool = False, - left_context_len: int = 0, - in_x_size: int = 0, - ): - super(DownsampledZipformerEncoder, self).__init__() - self.downsample_factor = downsample - self.downsample = AttentionDownsample( - input_dim, output_dim, downsample, is_pnnx=is_pnnx, in_x_size=in_x_size - ) - self.encoder = encoder - self.num_layers = encoder.num_layers - self.d_model = encoder.d_model - self.attention_dim = encoder.attention_dim - self.cnn_module_kernel = encoder.cnn_module_kernel - self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) - self.in_x_size = in_x_size - - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. feature_mask is expected to be already downsampled by - self.downsample_factor. - attn_mask: attention mask (optional). Should be downsampled already. - src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. - - Shape: - src: (S, N, E). - attn_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - src = self.encoder( - src, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - cached_avg: cached average value of past frames. - cached_len: length of past frames. - cached_key: cached key tensor for the first attention module. - cached_val: cached value tensor for the first attention module. - cached_val2: cached value tensor for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). 
- N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) - - src_orig = src - - src = self.downsample(src) - - ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = self.encoder.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - cached_key=cached_key, - cached_val=cached_val, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - ) - - src = self.upsample(src) - - if src.shape[0] != self.in_x_size: - # remove any extra frames that are not a multiple of downsample_factor - src = src[: self.in_x_size] - - return ( - self.out_combiner(src_orig, src), - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class AttentionDownsampleUnsqueeze(torch.nn.Module): - """We apply this operation only in PyTorch - and discards in ncnn. - """ - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.unsqueeze(1) - - -class AttentionDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - downsample: int, - is_pnnx: bool = False, - in_x_size: int = 0, - ): - super(AttentionDownsample, self).__init__() - - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) - - self.in_channels = in_channels - self.out_channels = out_channels - self.is_pnnx = is_pnnx - self.in_x_size = in_x_size - - self.unsqueeze = AttentionDownsampleUnsqueeze() - - # fill in the extra dimensions with a projection of the input - if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) - else: - self.extra_proj = None - self.downsample = downsample - - self.d_seq_len = (in_x_size + downsample - 1) // downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, 1, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, out_channels) - """ - assert src.shape[0] == self.in_x_size, ( - src.shape[0], - self.in_x_size, - src.shape, - type(src), - ) - assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) - if not self.is_pnnx: - (seq_len, batch_size, in_channels) = src.shape - else: - seq_len = self.in_x_size - batch_size = 1 - in_channels = self.in_channels - - ds = self.downsample - d_seq_len = self.d_seq_len - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - assert self.is_pnnx is False, "TODO(fangjun): Handle it!" - # right-pad src, repeating the last element. 
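The weighted-sum step that follows can also be read in isolation. A minimal sketch of the attention-based downsampling (assuming the sequence length is already a multiple of the downsampling factor, i.e. after the right-padding done above, and omitting the optional extra projection):

import torch

def attention_downsample(src, query, ds):
    # src: (seq_len, batch, channels), query: (channels,), ds: downsampling factor.
    seq_len, batch, channels = src.shape
    src = src.reshape(seq_len // ds, ds, batch, channels)
    scores = (src * query).sum(dim=-1, keepdim=True)   # (seq_len // ds, ds, batch, 1)
    weights = scores.softmax(dim=1)                    # normalize within each group of ds frames
    return (src * weights).sum(dim=1)                  # (seq_len // ds, batch, channels)

x = torch.randn(8, 2, 16)
q = torch.randn(16) * 16 ** -0.5
assert attention_downsample(x, q, 2).shape == (4, 2, 16)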
- pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) - - if not self.is_pnnx: - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - src = src.permute(0, 2, 1, 3).reshape( - d_seq_len, batch_size, ds * in_channels - ) - - if self.extra_proj is not None: - ans2 = self.extra_proj(src) - ans = torch.cat((ans, ans2), dim=2) - else: - src = src.reshape(d_seq_len, ds, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - assert ( - self.extra_proj is None - ), "The code for it being not None is not tested" - # ans = ans.unsqueeze(1) - ans = self.unsqueeze(ans) - # Note: In ncnn, we ignore self.unsqueeze - # so ans in ncnn is still a 2-D tensor, e.g., (8, 384) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - self.upsample = upsample - self.num_channels = num_channels - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.bias.shape[0] - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class SimpleCombinerIdentity(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - return src1 - - -class SimpleCombiner(torch.nn.Module): - """ - A very simple way of combining 2 vectors of 2 different dims, via a - learned weighted combination in the shared part of the dim. - Args: - dim1: the dimension of the first input, e.g. 256 - dim2: the dimension of the second input, e.g. 384. - The output will have the same dimension as dim2. 
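A compact sketch of the combination rule this class implements (learned scalar weight, zero-padding the smaller input up to the larger dimension); the tensor shapes in the example are illustrative only.

import torch

def simple_combine(src1, src2, weight1):
    # src1: (*, dim1), src2: (*, dim2) with dim2 >= dim1.
    a = src1 * weight1
    b = src2 * (1.0 - weight1)
    dim1, dim2 = a.shape[-1], b.shape[-1]
    if dim1 < dim2:
        a = torch.nn.functional.pad(a, (0, dim2 - dim1))  # zero-fill the missing channels
    return a + b

out = simple_combine(torch.randn(5, 3, 256), torch.randn(5, 3, 384), weight1=0.2)
assert out.shape == (5, 3, 384)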
- """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): - super(SimpleCombiner, self).__init__() - assert dim2 >= dim1, (dim2, dim1) - self.weight1 = nn.Parameter(torch.zeros(())) - self.min_weight = min_weight - self.dim1 = dim1 - self.dim2 = dim2 - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - """ - src1: (*, dim1) - src2: (*, dim2) - - Returns: a tensor of shape (*, dim2) - """ - assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) - - weight1 = self.weight1 - if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) - - src1 = src1 * weight1 - src2 = src2 * (1.0 - weight1) - - assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) - assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) - - src1_dim = self.dim1 - src2_dim = self.dim2 - - if src1_dim != src2_dim: - if src1_dim < src2_dim: - src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) - else: - src1 = src1[:src2_dim] - - return src1 + src2 - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - is_pnnx: bool = False, - x_size: int = 0, - left_context_len: int = 0, - ) -> None: - """Construct a PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(dropout_rate) - self.is_pnnx = is_pnnx - self.x_size = x_size - self.left_context_len = left_context_len - self.pe = None - if is_pnnx: - x_size_left = x_size + left_context_len - self.extend_pe(torch.tensor(0.0).expand(x_size_left)) - self.pe = self.pe[:, :-left_context_len] - assert self.pe.size(1) == x_size + left_context_len - 1 + x_size, ( - self.pe.size(1), - x_size, - left_context_len, - x_size, - self.pe.shape, - ) - else: - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - x_size_left = x.size(0) + left_context_len - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_left * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tensor: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). 
- - """ - if self.is_pnnx: - assert self.x_size == x.size(0), (self.x_size, x.size(0)) - assert self.left_context_len == left_context_len, ( - self.left_context_len, - left_context_len, - ) - return self.pe - - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_left - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(0), - ] - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionPermute(nn.Module): - """ncnn does not support permuatation relating to the batch axis 0. - This is a workaround for exporting to ncnn via PNNX. - """ - - def __init__(self, kind: int): - super().__init__() - self.kind = kind - assert self.kind in (2, 3), self.kind - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.kind == 2: - return x.permute(1, 0, 2) - elif self.kind == 3: - return x.permute(1, 2, 0) - else: - assert False, f"Unsupported kind {self.kind}" - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: total dimension of the model. - attention_dim: dimension in the attention module, may be less or more than embed_dim - but must be a multiple of num_heads. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, - is_pnnx: bool = False, - left_context_len: int = 0, - x_size: int = 0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.attention_dim = attention_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) - - self.is_pnnx = is_pnnx - - self.my_permute_pqv = RelPositionMultiheadAttentionPermute(kind=2) - self.my_permute_k_pos = RelPositionMultiheadAttentionPermute(kind=3) - self.left_context_len = left_context_len - self.x_size = x_size - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query (attention_dim,), key (attention_dim,) - + pos_dim * num_heads # value (attention_dim // 2,) - ) # positional encoding query (pos_dim * num_heads, ) - - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 - ) - - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. 
- self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option. - # they only copy their inputs. - self.copy_pos_query = Identity() - self.copy_query = Identity() - - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. 
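The slicing of the in_proj output into query, key, value and positional query, as described by the in_proj_dim comment above, can be sketched as follows; the dimensions 192 / 8 / 4 are only the typical values suggested by the shape comments elsewhere in this file, not requirements.

import torch

def split_in_proj(x_proj, attention_dim, num_heads, pos_dim):
    # x_proj: (seq_len, batch, 2*attention_dim + attention_dim//2 + pos_dim*num_heads)
    q = x_proj[..., 0:attention_dim]
    k = x_proj[..., attention_dim : 2 * attention_dim]
    v = x_proj[..., 2 * attention_dim : 2 * attention_dim + attention_dim // 2]
    p = x_proj[..., 2 * attention_dim + attention_dim // 2 :]
    return q, k, v, p

x_proj = torch.randn(16, 2, 2 * 192 + 192 // 2 + 4 * 8)
q, k, v, p = split_in_proj(x_proj, attention_dim=192, num_heads=8, pos_dim=4)
assert q.shape[-1] == 192 and k.shape[-1] == 192 and v.shape[-1] == 96 and p.shape[-1] == 32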
- """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. - - - Returns: (attn_output, attn_weights, cached_key, cached_val) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of - left context - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of - """ - ( - x, - weights, - cached_key, - cached_val, - ) = self.streaming_multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.out_proj.weight, - self.out_proj.bias, - cached_key=cached_key, - cached_val=cached_val, - ) - return x, weights, cached_key, cached_val - - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
- If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. - q = self.copy_query(q) # for diagnostics only, does nothing. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - if not torch.jit.is_scripting(): - if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. - # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights = attn_output_weights.masked_fill( - attn_mask, float("-inf") - ) - else: - attn_output_weights = attn_output_weights + attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - # If we are using chunk-wise attention mask and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`. 
At this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax). So we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights - - def streaming_multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - out_proj_weight, out_proj_bias: the output projection weight and bias. - cached_key: cached attention key tensor of left context. - cached_val: cached attention value tensor of left context. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. 
- """ - if not self.is_pnnx: - seq_len, bsz, _ = x_proj.size() - assert seq_len == self.x_size, (seq_len, self.x_size) - else: - seq_len = self.x_size - bsz = 1 - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[:, :, 0:attention_dim] # (x_size, N, attention_dim) - # return q, q, q, q - k = x_proj[:, :, attention_dim : 2 * attention_dim] - # k is (x_size, N, attention_dim) - value_dim = attention_dim // 2 - v = x_proj[:, :, 2 * attention_dim : 2 * attention_dim + value_dim] - # v is (x_size, 0, attention_dim//2) - - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[:, :, 2 * attention_dim + value_dim :] - # p is (x_size, N, pos_dim * num_heads) - - if not self.is_pnnx: - left_context_len = cached_key.shape[0] - else: - assert cached_key.shape[0] == self.left_context_len, ( - cached_key.shape, - self.left_context_len, - ) - left_context_len = self.left_context_len - - assert left_context_len > 0, left_context_len - assert cached_key.shape[0] == cached_val.shape[0], ( - cached_key.shape, - cached_val.shape, - ) - # Note: We need to fix the Concat in ncnn - # cached_key is (1, 64, 192) in ncnn - # k is (16, 192) in ncnn - # Pad cached left contexts - k = torch.cat([cached_key, k], dim=0) - # (left_context_len + x_size, N, attention_dim) - - v = torch.cat([cached_val, v], dim=0) - # v: (left_context_len + x_size, N, attention_dim//2) - # Update cached left contexts - if not self.is_pnnx: - cached_key = k[-left_context_len:, ...] - cached_val = v[-left_context_len:, ...] - else: - cached_key = k[self.x_size :] - cached_val = v[self.x_size :] - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape, - left_context_len, - ) - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape, - left_context_len, - ) - - if not self.is_pnnx: - # The length of key and value - kv_len = k.shape[0] - else: - kv_len = left_context_len + self.x_size - assert kv_len == k.shape[0], (kv_len, k.shape) - - if not self.is_pnnx: - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(kv_len, bsz, num_heads, head_dim) - - v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - # v is (bsz * num_heads, kv_len, head_dim//2) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 + left_context_len - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - else: - q = q.reshape(seq_len, num_heads, head_dim) - p = p.reshape(seq_len, num_heads, pos_dim) - k = k.reshape(kv_len, num_heads, head_dim) - # v = v.reshape(kv_len, num_heads, head_dim // 2).permute(1, 0, 2) - v = v.reshape(kv_len, num_heads, head_dim // 2) - v = self.my_permute_pqv(v) - # v is (num_heads, kv_len, head_dim//2) e.g., (8, 80, 12) - - # q = q.permute(1, 0, 2) # (head, time1, head_dim) - # p = p.permute(1, 0, 2) # (head, time1, pos_dim) - # k = k.permute(1, 2, 0) # (head, d_k, time2) - - q = self.my_permute_pqv(q) # (head, time1, head_dim), e.g., (8, 16, 24) - p = self.my_permute_pqv(p) # (head, time1, pos_dim), e.g., (8, 16, 4) - k = self.my_permute_k_pos(k) # (head, d_k, 
time2) e.g., (8, 24, 80) - - seq_len2 = 2 * seq_len - 1 + left_context_len - # pos = pos.reshape(seq_len2, num_heads, pos_dim).permute(1, 2, 0) - # pos shape now: (head, pos_dim, seq_len2) - - pos = pos.reshape(seq_len2, num_heads, pos_dim) - pos = self.my_permute_k_pos( - pos - ) # (head, pos_dim, seq_len2), e.g, (8, 4, 95) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) ,e.g., (1, 8, 16, 95) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - - if not self.is_pnnx: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - else: - pos_weights = pos_weights.as_strided( - (num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1) - pos_weights.stride(2), - pos_weights.stride(2), - ), - storage_offset=pos_weights.stride(2) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - attn_output = torch.bmm(attn_output_weights, v) - - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - # (8, 16, 12) - - if not self.is_pnnx: - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) - else: - attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) - attn_output = attn_output.reshape(seq_len, bsz, attention_dim // 2) - # We have changed InnerProduct in ncnn to treat - # (seq_len, bsz, attention_dim//2) as - # (seq_len, attention_dim//2) - - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) - - return attn_output, attn_output_weights, cached_key, cached_val - - def forward2( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. 
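The as_strided expression used in both multi_head_attention_forward and streaming_multi_head_attention_forward to turn the relative-position axis of pos_weights into absolute key positions can be exercised in isolation; the demo tensor below simply makes the selected windows visible.

import torch

def rel_to_abs(pos_weights, seq_len, kv_len):
    # pos_weights: (batch, heads, seq_len, num_rel); for each query position i,
    # re-striding selects the window of relative offsets that lines up with
    # absolute key positions 0..kv_len-1.
    return pos_weights.as_strided(
        (pos_weights.size(0), pos_weights.size(1), seq_len, kv_len),
        (
            pos_weights.stride(0),
            pos_weights.stride(1),
            pos_weights.stride(2) - pos_weights.stride(3),
            pos_weights.stride(3),
        ),
        storage_offset=pos_weights.stride(3) * (seq_len - 1),
    )

p = torch.arange(2 * 4 - 1, dtype=torch.float32).repeat(1, 1, 4, 1)  # (1, 1, 4, 7)
print(rel_to_abs(p, seq_len=4, kv_len=4))
# row i is the slice [seq_len - 1 - i : seq_len - 1 - i + kv_len] of the relative axis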
- v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not torch.jit.is_scripting(): - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output) - - def streaming_forward2( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - cached_val: cached attention value tensor of left context. - Returns: - - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - - updated cached attention value tensor of left context. - """ - num_heads = self.num_heads - - assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) - assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) - - if not self.is_pnnx: - (seq_len, bsz, embed_dim) = x.shape - else: - seq_len = self.x_size - bsz = 1 - embed_dim = self.embed_dim - - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - - assert cached_val.shape[0] == self.left_context_len, ( - cached_val.shape[0], - self.left_context_len, - ) - - left_context_len = self.left_context_len - assert left_context_len > 0, left_context_len - v = torch.cat([cached_val, v], dim=0) - cached_val = v[-left_context_len:] - - seq_len2 = left_context_len + seq_len - if not self.is_pnnx: - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) - else: - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2) - # v = v.permute(1, 0, 2) - v = self.my_permute_pqv(v) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not self.is_pnnx: - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - else: - attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) - attn_output = attn_output.reshape(seq_len, bsz, self.attention_dim // 2) - # We have changed InnerProduct in ncnn to ignore bsz - # when invoking self.out_proj2(attn_output) - - # returned value is of shape (seq_len, bsz, embed_dim), like x. 
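The left-context cache handling shared by the streaming attention paths (concatenate the cache in front of the new frames, then keep the most recent left_context_len frames as the next cache) reduces to a few lines; the names below are illustrative.

import torch

def append_and_trim_cache(cache, new_frames):
    # cache: (left_context_len, batch, dim), new_frames: (chunk_len, batch, dim)
    left_context_len = cache.shape[0]
    full = torch.cat([cache, new_frames], dim=0)   # what the attention actually sees
    new_cache = full[-left_context_len:]           # carried over to the next chunk
    return full, new_cache

cache = torch.zeros(64, 1, 96)
chunk = torch.randn(16, 1, 96)
full, cache = append_and_trim_cache(cache, chunk)
assert full.shape[0] == 80 and cache.shape[0] == 64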
- return self.out_proj2(attn_output), cached_val - - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) - # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" - ) - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) - self.activation = DoubleSwish() - self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.balancer(x) - x = self.activation(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - is_pnnx: bool = False, - x_size: int = 0, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0, kernel_size - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. 
This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, - ) - - # Will pad cached left context - self.lorder = kernel_size - 1 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=0, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, - max_abs=20.0, - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.05, - ) - - self.is_pnnx = is_pnnx - self.x_size = x_size - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - # 1D Depthwise Conv - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch: - (batch, #time), contains bool in masked positions. - cache: Cached left context for depthwise_conv, with shape of - (batch, channels, #kernel_size-1). Only used in real streaming decoding. - - Returns: - A tuple of 2 tensors: - - Output tensor (#time, batch, channels). - - New cached left context, with shape of (batch, channels, #kernel_size-1). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
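The causal-convolution bookkeeping here can be verified with a short sketch: left-padding the whole sequence with kernel_size - 1 zeros (as in forward) and carrying a kernel_size - 1 frame cache between chunks (as in streaming_forward) give the same depthwise-convolution output. The sizes are arbitrary.

import torch

torch.manual_seed(0)
channels, kernel_size, chunk = 4, 5, 8
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)

x = torch.randn(1, channels, 2 * chunk)                       # (batch, channels, time)

# Non-streaming: pad (kernel_size - 1) zeros on the left, run once.
full = conv(torch.nn.functional.pad(x, (kernel_size - 1, 0)))

# Streaming: keep the last (kernel_size - 1) input frames as cache between chunks.
cache = torch.zeros(1, channels, kernel_size - 1)
outs = []
for start in range(0, x.shape[2], chunk):
    piece = torch.cat([cache, x[:, :, start : start + chunk]], dim=2)
    cache = piece[:, :, -(kernel_size - 1):]
    outs.append(conv(piece))

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)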
- - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - assert cache.shape == (x.size(0), x.size(1), self.lorder), ( - cache.shape, - (x.size(0), x.size(1), self.lorder), - ) - x = torch.cat([cache, x], dim=2) - - cache = x[:, :, self.x_size :] - - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: float = 0.1, - is_pnnx: bool = False, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-7)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer2_channels: - Number of channels in layer2 - layer3_channels: - Number of channels in layer3 - is_pnnx: - True if we are converting the model to PNNX format. - False otherwise. - """ - assert in_channels >= 7, in_channels - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - # After this layer (N, 1, T, C) -> (N, layer1_channels, T-2, C) - ActivationBalancer(layer1_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - # After this layer (N, layer1_channels, T-2, C) -> (N, layer2_channels, ((T-2) - 3)//2+1, (C-3)//2+1) - # i.e., (N, layer2_channels, (T-5)//2+1, (C-3)//2+1) - # i.e., (N, layer2_channels, (T-3)//2, (C-1)//2) - ActivationBalancer(layer2_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - # After this layer, (N, layer2_channels, (T-3)//2, (C-1)//2) - # -> - # (N, layer3_channels, (T-3)//2-2, ((C-1)//2 - 3)//2 + 1) - # (N, layer3_channels, (T-7)//2, (C-3)//4) - ActivationBalancer(layer3_channels, channel_dim=1), - DoubleSwish(), - ) - out_height = (((in_channels - 1) // 2) - 1) // 2 - self.out = ScaledLinear(out_height * layer3_channels, out_channels) - self.dropout = nn.Dropout(dropout) - - # ncnn supports only batch size == 1 - self.is_pnnx = is_pnnx - self.conv_out_dim = self.out.weight.shape[1] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). 
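As a quick check of the shape bookkeeping quoted in the layer comments above, the following sketch (editor-added; conv_out_len is an assumed helper, not part of the file) applies the standard convolution output-length formula to the time axis of the three layers and confirms that T maps to (T - 7) // 2 for every T >= 7.

def conv_out_len(t: int, k: int, s: int = 1, p: int = 0) -> int:
    # floor((t + 2*p - k) / s) + 1, the usual conv output-size formula
    return (t + 2 * p - k) // s + 1

for T in range(7, 200):
    t1 = conv_out_len(T, k=3, s=1, p=0)   # layer1: time padding is 0, so T -> T - 2
    t2 = conv_out_len(t1, k=3, s=2, p=0)  # layer2: -> (T - 3) // 2
    t3 = conv_out_len(t2, k=3, s=1, p=0)  # layer3: time stride is 1, -> (T - 7) // 2
    assert t2 == (T - 3) // 2 and t3 == (T - 7) // 2, (T, t2, t3)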
- - Returns: - Return a tensor of shape (N, (T-7)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - - if torch.jit.is_tracing() and self.is_pnnx: - x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) - x = self.out(x) - else: - # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) - # Now x is of shape (N, (T-7)//2, odim) - x = self.dropout(x) - return x - - -def _test_zipformer_main(): - feature_dim = 50 - batch_size = 5 - seq_len = 47 - feature_dim = 50 - # Just make sure the forward pass runs. - - c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - decode_chunk_size=4, - ) - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -def _test_conv2d_subsampling(): - num_features = 80 - encoder_dims = 384 - dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) - for i in range(20, 40): - x = torch.rand(2, i, num_features) - y = encoder_embed(x) - assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - - -def _test_pooling_module(): - N, S, C = 2, 12, 32 - chunk_len = 4 - m = PoolingModule(d_model=C) - - # test chunk-wise forward with padding_mask - x = torch.randn(S, N, C) - y = m(x) - cached_len = torch.zeros(N, dtype=torch.int32) - cached_avg = torch.zeros(N, C) - for i in range(S // chunk_len): - start = i * chunk_len - end = start + chunk_len - x_chunk = x[start:end] - y_chunk, cached_len, cached_avg = m.streaming_forward( - x_chunk, - cached_len=cached_len, - cached_avg=cached_avg, - ) - assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) - - -def _test_state_stack_unstack(): - m = Zipformer( - num_features=80, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - zipformer_downsampling_factors=(4, 8), - num_left_chunks=2, - decode_chunk_size=8, - ) - s1 = m.get_init_state() - s2 = m.get_init_state() - states = stack_states([s1, s2]) - new_s1, new_s2 = unstack_states(states) - for i in range(m.num_encoders * 7): - for x, y in zip(s1[i], new_s1[i]): - assert torch.equal(x, y) - for x, y in zip(s2[i], new_s2[i]): - assert torch.equal(x, y) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main() - _test_conv2d_subsampling() - _test_pooling_module() - _test_state_stack_unstack() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 100755 index 29a2cd7f7..000000000 --- 
a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script takes as input an FST in k2 format and convert it -to an FST in OpenFST format. - -The generated FST is saved into a binary file and its type is -StdVectorFst. - -Usage examples: -(1) Convert an acceptor - - ./convert-k2-to-openfst.py in.pt binary.fst - -(2) Convert a transducer - - ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import kaldifst.utils -import torch - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--olabels", - type=str, - default=None, - help="""If not empty, the input FST is assumed to be a transducer - and we use its attribute specified by "olabels" as the output labels. - """, - ) - parser.add_argument( - "input_filename", - type=str, - help="Path to the input FST in k2 format", - ) - - parser.add_argument( - "output_filename", - type=str, - help="Path to the output FST in OpenFst format", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - logging.info(f"{vars(args)}") - - input_filename = args.input_filename - output_filename = args.output_filename - olabels = args.olabels - - if Path(output_filename).is_file(): - logging.info(f"{output_filename} already exists - skipping") - return - - assert Path(input_filename).is_file(), f"{input_filename} does not exist" - logging.info(f"Loading {input_filename}") - k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) - if olabels: - assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" - - p = Path(output_filename).parent - if not p.is_dir(): - logging.info(f"Creating {p}") - p.mkdir(parents=True) - - logging.info("Converting (May take some time if the input FST is large)") - fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) - logging.info(f"Saving to {output_filename}") - fst.write(output_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 120000 index 000000000..24efe5eae --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py deleted file mode 100755 index f2fd762d7..000000000 --- a/icefall/shared/make_kn_lm.py +++ /dev/null @@ -1,443 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) -# 2018 Ruizhe Huang -# Apache 2.0. 
- -# This is an implementation of computing Kneser-Ney smoothed language model -# in the same way as srilm. This is a back-off, unmodified version of -# Kneser-Ney smoothing, which produces the same results as the following -# command (as an example) of srilm: -# -# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ -# -text corpus.txt -lm lm.arpa -# -# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py -# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html - -import argparse -import io -import math -import os -import re -import sys -from collections import Counter, defaultdict - -parser = argparse.ArgumentParser( - description=""" - Generate kneser-ney language model as arpa format. By default, - it will read the corpus from standard input, and output to standard output. - """ -) -parser.add_argument( - "-ngram-order", - type=int, - default=4, - choices=[2, 3, 4, 5, 6, 7], - help="Order of n-gram", -) -parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") -parser.add_argument( - "-lm", type=str, default=None, help="Path to output arpa file for language models" -) -parser.add_argument( - "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level" -) -args = parser.parse_args() - -# For encoding-agnostic scripts, we assume byte stream as input. -# Need to be very careful about the use of strip() and split() -# in this case, because there is a latin-1 whitespace character -# (nbsp) which is part of the unicode encoding range. -# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 -default_encoding = "latin-1" - -strip_chars = " \t\r\n" -whitespace = re.compile("[ \t]+") - - -class CountsForHistory: - # This class (which is more like a struct) stores the counts seen in a - # particular history-state. It is used inside class NgramCounts. - # It really does the job of a dict from int to float, but it also - # keeps track of the total count. - def __init__(self): - # The 'lambda: defaultdict(float)' is an anonymous function taking no - # arguments that returns a new defaultdict(float). - self.word_to_count = defaultdict(int) - # using a set to count the number of unique contexts - self.word_to_context = defaultdict(set) - self.word_to_f = dict() # discounted probability - self.word_to_bow = dict() # back-off weight - self.total_count = 0 - - def words(self): - return self.word_to_count.keys() - - def __str__(self): - # e.g. returns ' total=12: 3->4, 4->6, -1->2' - return " total={0}: {1}".format( - str(self.total_count), - ", ".join( - [ - "{0} -> {1}".format(word, count) - for word, count in self.word_to_count.items() - ] - ), - ) - - def add_count(self, predicted_word, context_word, count): - assert count >= 0 - - self.total_count += count - self.word_to_count[predicted_word] += count - if context_word is not None: - self.word_to_context[predicted_word].add(context_word) - - -class NgramCounts: - # A note on data-structure. Firstly, all words are represented as - # integers. We store n-gram counts as an array, indexed by (history-length - # == n-gram order minus one) (note: python calls arrays "lists") of dicts - # from histories to counts, where histories are arrays of integers and - # "counts" are dicts from integer to float. 
For instance, when - # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd - # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an - # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. - def __init__(self, ngram_order, bos_symbol="", eos_symbol=""): - assert ngram_order >= 2 - - self.ngram_order = ngram_order - self.bos_symbol = bos_symbol - self.eos_symbol = eos_symbol - - self.counts = [] - for n in range(ngram_order): - self.counts.append(defaultdict(lambda: CountsForHistory())) - - self.d = [] # list of discounting factor for each order of ngram - - # adds a raw count (called while processing input data). - # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' - # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be - # 1. - def add_count(self, history, predicted_word, context_word, count): - self.counts[len(history)][history].add_count( - predicted_word, context_word, count - ) - - # 'line' is a string containing a sequence of integer word-ids. - # This function adds the un-smoothed counts from this line of text. - def add_raw_counts_from_line(self, line): - if line == "": - words = [self.bos_symbol, self.eos_symbol] - else: - words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] - - for i in range(len(words)): - for n in range(1, self.ngram_order + 1): - if i + n > len(words): - break - ngram = words[i : i + n] - predicted_word = ngram[-1] - history = tuple(ngram[:-1]) - if i == 0 or n == self.ngram_order: - context_word = None - else: - context_word = words[i - 1] - - self.add_count(history, predicted_word, context_word, 1) - - def add_raw_counts_from_standard_input(self): - lines_processed = 0 - # byte stream as input - infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) - for line in infile: - line = line.strip(strip_chars) - self.add_raw_counts_from_line(line) - lines_processed += 1 - if lines_processed == 0 or args.verbose > 0: - print( - "make_phone_lm.py: processed {0} lines of input".format( - lines_processed - ), - file=sys.stderr, - ) - - def add_raw_counts_from_file(self, filename): - lines_processed = 0 - with open(filename, encoding=default_encoding) as fp: - for line in fp: - line = line.strip(strip_chars) - self.add_raw_counts_from_line(line) - lines_processed += 1 - if lines_processed == 0 or args.verbose > 0: - print( - "make_phone_lm.py: processed {0} lines of input".format( - lines_processed - ), - file=sys.stderr, - ) - - def cal_discounting_constants(self): - # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), - # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). - # This constant is used similarly to absolute discounting. - # Return value: d is a list of floats, where d[N+1] = D_N - - # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 - # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, - # but perhaps this is not the case for some other scenarios. - self.d = [0] - for n in range(1, self.ngram_order): - this_order_counts = self.counts[n] - n1 = 0 - n2 = 0 - for hist, counts_for_hist in this_order_counts.items(): - stat = Counter(counts_for_hist.word_to_count.values()) - n1 += stat[1] - n2 += stat[2] - assert n1 + 2 * n2 > 0 - - # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, - # which could happen if the number of symbols is small. 
- # Otherwise, zero discounting constant can cause division by zero in computing BOW. - self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2)) - - def cal_f(self): - # f(a_z) is a probability distribution of word sequence a_z. - # Typically f(a_z) is discounted to be less than the ML estimate so we have - # some leftover probability for the z words unseen in the context (a_). - # - # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams - # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams - - # highest order N-grams - n = self.ngram_order - 1 - this_order_counts = self.counts[n] - for hist, counts_for_hist in this_order_counts.items(): - for w, c in counts_for_hist.word_to_count.items(): - counts_for_hist.word_to_f[w] = ( - max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count - ) - - # lower order N-grams - for n in range(0, self.ngram_order - 1): - this_order_counts = self.counts[n] - for hist, counts_for_hist in this_order_counts.items(): - - n_star_star = 0 - for w in counts_for_hist.word_to_count.keys(): - n_star_star += len(counts_for_hist.word_to_context[w]) - - if n_star_star != 0: - for w in counts_for_hist.word_to_count.keys(): - n_star_z = len(counts_for_hist.word_to_context[w]) - counts_for_hist.word_to_f[w] = ( - max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star - ) - else: # patterns begin with , they do not have "modified count", so use raw count instead - for w in counts_for_hist.word_to_count.keys(): - n_star_z = counts_for_hist.word_to_count[w] - counts_for_hist.word_to_f[w] = ( - max((n_star_z - self.d[n]), 0) - * 1.0 - / counts_for_hist.total_count - ) - - def cal_bow(self): - # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. - # Thus, two sorts of ngrams do not have a bow: - # 1) highest order ngram - # 2) ngrams ending in - # - # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) - # Note that Z1 is the set of all words with c(a_z) > 0 - - # highest order N-grams - n = self.ngram_order - 1 - this_order_counts = self.counts[n] - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - counts_for_hist.word_to_bow[w] = None - - # lower order N-grams - for n in range(0, self.ngram_order - 1): - this_order_counts = self.counts[n] - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - if w == self.eos_symbol: - counts_for_hist.word_to_bow[w] = None - else: - a_ = hist + (w,) - - assert len(a_) < self.ngram_order - assert a_ in self.counts[len(a_)].keys() - - a_counts_for_hist = self.counts[len(a_)][a_] - - sum_z1_f_a_z = 0 - for u in a_counts_for_hist.word_to_count.keys(): - sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] - - sum_z1_f_z = 0 - _ = a_[1:] - _counts_for_hist = self.counts[len(_)][_] - # Should be careful here: what is Z1 - for u in a_counts_for_hist.word_to_count.keys(): - sum_z1_f_z += _counts_for_hist.word_to_f[u] - - if sum_z1_f_z < 1: - # assert sum_z1_f_a_z < 1 - counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / ( - 1.0 - sum_z1_f_z - ) - else: - counts_for_hist.word_to_bow[w] = None - - def print_raw_counts(self, info_string): - # these are useful for debug. 
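To make the formulas implemented in cal_discounting_constants, cal_f and cal_bow above concrete, here is a small numeric sketch (editor-added) with assumed toy counts; the unigram values stand in for continuation-count estimates and are not derived from any corpus.

n1, n2 = 10, 5
D = n1 / (n1 + 2 * n2)                    # discounting constant, here 0.5

bigram_counts = {"cat": 2, "dog": 1}      # raw counts for the context ("the",)
total = sum(bigram_counts.values())       # c(the *) = 3

# f(the z) = max(c(the z) - D, 0) / c(the *)
f_bigram = {w: max(c - D, 0) / total for w, c in bigram_counts.items()}

# Lower-order (unigram) estimates, assumed here, normally from continuation counts.
f_unigram = {"cat": 0.5, "dog": 0.25, "the": 0.25}

# bow(the) = (1 - sum over seen z of f(the z)) / (1 - sum over the same z of f(z))
bow = (1.0 - sum(f_bigram.values())) / (
    1.0 - sum(f_unigram[w] for w in bigram_counts)
)

# p(z | the) is f(the z) for seen z, otherwise bow * f(z); it sums to one.
p = {w: f_bigram.get(w, bow * f_unigram[w]) for w in f_unigram}
assert abs(sum(p.values()) - 1.0) < 1e-12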
- print(info_string) - res = [] - for this_order_counts in self.counts: - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - ngram = " ".join(hist) + " " + w - ngram = ngram.strip(strip_chars) - - res.append( - "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]) - ) - res.sort(reverse=True) - for r in res: - print(r) - - def print_modified_counts(self, info_string): - # these are useful for debug. - print(info_string) - res = [] - for this_order_counts in self.counts: - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - ngram = " ".join(hist) + " " + w - ngram = ngram.strip(strip_chars) - - modified_count = len(counts_for_hist.word_to_context[w]) - raw_count = counts_for_hist.word_to_count[w] - - if modified_count == 0: - res.append("{0}\t{1}".format(ngram, raw_count)) - else: - res.append("{0}\t{1}".format(ngram, modified_count)) - res.sort(reverse=True) - for r in res: - print(r) - - def print_f(self, info_string): - # these are useful for debug. - print(info_string) - res = [] - for this_order_counts in self.counts: - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - ngram = " ".join(hist) + " " + w - ngram = ngram.strip(strip_chars) - - f = counts_for_hist.word_to_f[w] - if f == 0: # f() is always 0 - f = 1e-99 - - res.append("{0}\t{1}".format(ngram, math.log(f, 10))) - res.sort(reverse=True) - for r in res: - print(r) - - def print_f_and_bow(self, info_string): - # these are useful for debug. - print(info_string) - res = [] - for this_order_counts in self.counts: - for hist, counts_for_hist in this_order_counts.items(): - for w in counts_for_hist.word_to_count.keys(): - ngram = " ".join(hist) + " " + w - ngram = ngram.strip(strip_chars) - - f = counts_for_hist.word_to_f[w] - if f == 0: # f() is always 0 - f = 1e-99 - - bow = counts_for_hist.word_to_bow[w] - if bow is None: - res.append("{1}\t{0}".format(ngram, math.log(f, 10))) - else: - res.append( - "{1}\t{0}\t{2}".format( - ngram, math.log(f, 10), math.log(bow, 10) - ) - ) - res.sort(reverse=True) - for r in res: - print(r) - - def print_as_arpa( - self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1") - ): - # print as ARPA format. - - print("\\data\\", file=fout) - for hist_len in range(self.ngram_order): - # print the number of n-grams. 
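For orientation, the ARPA layout written by this method looks like the hand-made toy below (made-up base-10 log probabilities, with the optional back-off weight in the third column); it is an illustration of the format only, not output produced by this script.

\data\
ngram 1=3
ngram 2=2

\1-grams:
-0.3010300	the	-0.3010300
-0.4771213	cat	-0.1760913
-0.6020600	</s>

\2-grams:
-0.1760913	the cat
-0.3010300	cat </s>

\end\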
- print( - "ngram {0}={1}".format( - hist_len + 1, - sum( - [ - len(counts_for_hist.word_to_f) - for counts_for_hist in self.counts[hist_len].values() - ] - ), - ), - file=fout, - ) - - print("", file=fout) - - for hist_len in range(self.ngram_order): - print("\\{0}-grams:".format(hist_len + 1), file=fout) - - this_order_counts = self.counts[hist_len] - for hist, counts_for_hist in this_order_counts.items(): - for word in counts_for_hist.word_to_count.keys(): - ngram = hist + (word,) - prob = counts_for_hist.word_to_f[word] - bow = counts_for_hist.word_to_bow[word] - - if prob == 0: # f() is always 0 - prob = 1e-99 - - line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram)) - if bow is not None: - line += "\t{0}".format("%.7f" % math.log10(bow)) - print(line, file=fout) - print("", file=fout) - print("\\end\\", file=fout) - - -if __name__ == "__main__": - - ngram_counts = NgramCounts(args.ngram_order) - if args.text is None: - ngram_counts.add_raw_counts_from_standard_input() - else: - assert os.path.isfile(args.text) - ngram_counts.add_raw_counts_from_file(args.text) - - ngram_counts.cal_discounting_constants() - ngram_counts.cal_f() - ngram_counts.cal_bow() - - if args.lm is None: - ngram_counts.print_as_arpa() - else: - with open(args.lm, "w", encoding=default_encoding) as f: - ngram_counts.print_as_arpa(fout=f) diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py new file mode 120000 index 000000000..6f6032470 --- /dev/null +++ b/icefall/shared/make_kn_lm.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/make_kn_lm.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 100755 index b1ebee9ea..000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1,630 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./ngram_entropy_pruning.py \ - -threshold 1e-8 \ - -lm download/lm/4gram.arpa \ - -write-lm download/lm/4gram_pruned_1e8.arpa - -This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. -This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' -in the same way as SRILM. -""" - - -import argparse -import gzip -import logging -import math -import re -from collections import OrderedDict, defaultdict -from enum import Enum, unique -from io import StringIO - -parser = argparse.ArgumentParser( - description=""" - Prune an n-gram language model based on the relative entropy - between the original and the pruned model, based on Andreas Stolcke's paper. - An n-gram entry is removed, if the removal causes (training set) perplexity - of the model to increase by less than threshold relative. - - The command takes an arpa file and a pruning threshold as input, - and outputs a pruned arpa file. 
- """ -) -parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") -parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") -parser.add_argument( - "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" -) -parser.add_argument( - "-minorder", - type=int, - default=1, - help="The minorder parameter limits pruning to ngrams of that length and above.", -) -parser.add_argument( - "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" -) -parser.add_argument( - "-verbose", - type=int, - default=2, - choices=[0, 1, 2, 3, 4, 5], - help="Verbose level, where 0 is most noisy; 5 is most silent", -) -args = parser.parse_args() - -default_encoding = args.encoding -logging.basicConfig( - format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", - level=args.verbose * 10, -) - - -class Context(dict): - """ - This class stores data for a context h. - It behaves like a python dict object, except that it has several - additional attributes. - """ - - def __init__(self): - super().__init__() - self.log_bo = None - - -class Arpa: - """ - This is a class that implement the data structure of an APRA LM. - It (as well as some other classes) is modified based on the library - by Stefan Fischer: - https://github.com/sfischer13/python-arpa - """ - - UNK = "" - SOS = "" - EOS = "" - FLOAT_NDIGITS = 7 - base = 10 - - @staticmethod - def _check_input(my_input): - if not my_input: - raise ValueError - elif isinstance(my_input, tuple): - return my_input - elif isinstance(my_input, list): - return tuple(my_input) - elif isinstance(my_input, str): - return tuple(my_input.strip().split(" ")) - else: - raise ValueError - - @staticmethod - def _check_word(input_word): - if not isinstance(input_word, str): - raise ValueError - if " " in input_word: - raise ValueError - - def _replace_unks(self, words): - return tuple((w if w in self else self._unk) for w in words) - - def __init__(self, path=None, encoding=None, unk=None): - self._counts = OrderedDict() - self._ngrams = ( - OrderedDict() - ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) - self._vocabulary = set() - if unk is None: - self._unk = self.UNK - - if path is not None: - self.loadf(path, encoding) - - def __contains__(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] - - def contains_word(self, word): - self._check_word(word) - return word in self._vocabulary - - def add_count(self, order, count): - self._counts[order] = count - self._ngrams[order - 1] = defaultdict(Context) - - def update_counts(self): - for order in range(1, self.order() + 1): - count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) - if count > 0: - self._counts[order] = count - - def add_entry(self, ngram, p, bo=None, order=None): - # Note: ngram is a tuple of strings, e.g. 
("w1", "w2", "w3") - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - - # Note that p and bo here are in fact in the log domain (self.base = 10) - h_context = self._ngrams[len(h)][h] - h_context[w] = p - if bo is not None: - self._ngrams[len(ngram)][ngram].log_bo = bo - - for word in ngram: - self._vocabulary.add(word) - - def counts(self): - return sorted(self._counts.items()) - - def order(self): - return max(self._counts.keys(), default=None) - - def vocabulary(self, sort=True): - if sort: - return sorted(self._vocabulary) - else: - return self._vocabulary - - def _entries(self, order): - return ( - self._entry(h, w) - for h, wlist in self._ngrams[order - 1].items() - for w in wlist - ) - - def _entry(self, h, w): - # return the entry for the ngram (h, w) - ngram = h + (w,) - log_p = self._ngrams[len(h)][h][w] - log_bo = self._log_bo(ngram) - if log_bo is not None: - return ( - round(log_p, self.FLOAT_NDIGITS), - ngram, - round(log_bo, self.FLOAT_NDIGITS), - ) - else: - return round(log_p, self.FLOAT_NDIGITS), ngram - - def _log_bo(self, ngram): - if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: - return self._ngrams[len(ngram)][ngram].log_bo - else: - return None - - def _log_p(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: - return self._ngrams[len(h)][h][w] - else: - return None - - def log_p_raw(self, ngram): - log_p = self._log_p(ngram) - if log_p is not None: - return log_p - else: - if len(ngram) == 1: - raise KeyError - else: - log_bo = self._log_bo(ngram[:-1]) - if log_bo is None: - log_bo = 0 - return log_bo + self.log_p_raw(ngram[1:]) - - def log_joint_prob(self, sequence): - # Compute the joint prob of the sequence based on the chain rule - # Note that sequence should be a tuple of strings - # - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 - - log_joint_p = 0 - seq = sequence - while len(seq) > 0: - log_joint_p += self.log_p_raw(seq) - seq = seq[:-1] - - # If we're computing the marginal probability of the unigram - # context we have to look up instead since the former - # has prob = 0. 
- if len(seq) == 1 and seq[0] == self.SOS: - seq = (self.EOS,) - - return log_joint_p - - def set_new_context(self, h): - old_context = self._ngrams[len(h)][h] - self._ngrams[len(h)][h] = Context() - return old_context - - def log_p(self, ngram): - words = self._check_input(ngram) - if self._unk: - words = self._replace_unks(words) - return self.log_p_raw(words) - - def log_s(self, sentence, sos=SOS, eos=EOS): - words = self._check_input(sentence) - if self._unk: - words = self._replace_unks(words) - if sos: - words = (sos,) + words - if eos: - words = words + (eos,) - result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) - if sos: - result = result - self.log_p_raw(words[:1]) - return result - - def p(self, ngram): - return self.base ** self.log_p(ngram) - - def s(self, sentence): - return self.base ** self.log_s(sentence) - - def write(self, fp): - fp.write("\n\\data\\\n") - for order, count in self.counts(): - fp.write("ngram {}={}\n".format(order, count)) - fp.write("\n") - for order, _ in self.counts(): - fp.write("\\{}-grams:\n".format(order)) - for e in self._entries(order): - prob = e[0] - ngram = " ".join(e[1]) - if len(e) == 2: - fp.write("{}\t{}\n".format(prob, ngram)) - elif len(e) == 3: - backoff = e[2] - fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) - else: - raise ValueError - fp.write("\n") - fp.write("\\end\\\n") - - -class ArpaParser: - """ - This is a class that implement a parser of an arpa file - """ - - @unique - class State(Enum): - DATA = 1 - COUNT = 2 - HEADER = 3 - ENTRY = 4 - - re_count = re.compile(r"^ngram (\d+)=(\d+)$") - re_header = re.compile(r"^\\(\d+)-grams:$") - re_entry = re.compile( - "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" - "\t" - "(\\S+( \\S+)*)" - "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" - ) - - def _parse(self, fp): - self._result = [] - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - for line in fp: - line = line.strip() - if self._state == self.State.DATA: - self._data(line) - elif self._state == self.State.COUNT: - self._count(line) - elif self._state == self.State.HEADER: - self._header(line) - elif self._state == self.State.ENTRY: - self._entry(line) - if self._state != self.State.DATA: - raise Exception(line) - return self._result - - def _data(self, line): - if line == "\\data\\": - self._state = self.State.COUNT - self._tmp_model = Arpa() - else: - pass # skip comment line - - def _count(self, line): - match = self.re_count.match(line) - if match: - order = match.group(1) - count = match.group(2) - self._tmp_model.add_count(int(order), int(count)) - elif not line: - self._state = self.State.HEADER # there are no counts - else: - raise Exception(line) - - def _header(self, line): - match = self.re_header.match(line) - if match: - self._state = self.State.ENTRY - self._tmp_order = int(match.group(1)) - elif line == "\\end\\": - self._result.append(self._tmp_model) - self._state = self.State.DATA - self._tmp_model = None - self._tmp_order = None - elif not line: - pass # skip empty line - else: - raise Exception(line) - - def _entry(self, line): - match = self.re_entry.match(line) - if match: - p = self._float_or_int(match.group(1)) - ngram = tuple(match.group(4).split(" ")) - bo_match = match.group(7) - bo = self._float_or_int(bo_match) if bo_match else None - self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) - elif not line: - self._state = self.State.HEADER # last entry - else: - raise Exception(line) - - @staticmethod - def _float_or_int(s): - f = float(s) - i = int(f) - if 
str(i) == s: # don't drop trailing ".0" - return i - else: - return f - - def load(self, fp): - """Deserialize fp (a file-like object) to a Python object.""" - return self._parse(fp) - - def loadf(self, path, encoding=None): - """Deserialize path (.arpa, .gz) to a Python object.""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - else: - with open(path, mode="rt", encoding=encoding) as f: - return self.load(f) - - def loads(self, s): - """Deserialize s (a str) to a Python object.""" - with StringIO(s) as f: - return self.load(f) - - def dump(self, obj, fp): - """Serialize obj to fp (a file-like object) in ARPA format.""" - obj.write(fp) - - def dumpf(self, obj, path, encoding=None): - """Serialize obj to path in ARPA format (.arpa, .gz).""" - path = str(path) - if path.endswith(".gz"): - with gzip.open(path, mode="wt", encoding=encoding) as f: - return self.dump(obj, f) - else: - with open(path, mode="wt", encoding=encoding) as f: - self.dump(obj, f) - - def dumps(self, obj): - """Serialize obj to an ARPA formatted str.""" - with StringIO() as f: - self.dump(obj, f) - return f.getvalue() - - -def add_log_p(prev_log_sum, log_p, base): - return math.log(base**log_p + base**prev_log_sum, base) - - -def compute_numerator_denominator(lm, h): - log_sum_seen_h = -math.inf - log_sum_seen_h_lower = -math.inf - base = lm.base - for w, log_p in lm._ngrams[len(h)][h].items(): - log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) - - ngram = h + (w,) - log_p_lower = lm.log_p_raw(ngram[1:]) - log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) - - numerator = 1.0 - base**log_sum_seen_h - denominator = 1.0 - base**log_sum_seen_h_lower - return numerator, denominator - - -def prune(lm, threshold, minorder): - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 - - for i in range( - lm.order(), max(minorder - 1, 1), -1 - ): # i is the order of the ngram (h, w) - logging.info("processing %d-grams ..." % i) - count_pruned_ngrams = 0 - - h_dict = lm._ngrams[i - 1] - for h in list(h_dict.keys()): - # old backoff weight, BOW(h) - log_bow = lm._log_bo(h) - if log_bow is None: - log_bow = 0 - - # Compute numerator and denominator of the backoff weight, - # so that we can quickly compute the BOW adjustment due to - # leaving out one prob. 
- numerator, denominator = compute_numerator_denominator(lm, h) - - # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 - - # Compute the marginal probability of the context, P(h) - h_log_p = lm.log_joint_prob(h) - - all_pruned = True - pruned_w_set = set() - - for w, log_p in h_dict[h].items(): - ngram = h + (w,) - - # lower-order estimate for ngramProb, P(w|h') - backoff_prob = lm.log_p_raw(ngram[1:]) - - # Compute BOW after removing ngram, BOW'(h) - new_log_bow = math.log( - numerator + lm.base**log_p, lm.base - ) - math.log(denominator + lm.base**backoff_prob, lm.base) - - # Compute change in entropy due to removal of ngram - delta_prob = backoff_prob + new_log_bow - log_p - delta_entropy = -(lm.base**h_log_p) * ( - (lm.base**log_p) * delta_prob - + numerator * (new_log_bow - log_bow) - ) - - # compute relative change in model (training set) perplexity - perp_change = lm.base**delta_entropy - 1.0 - - pruned = threshold > 0 and perp_change < threshold - - # Make sure we don't prune ngrams whose backoff nodes are needed - if ( - pruned - and len(ngram) in lm._ngrams - and len(lm._ngrams[len(ngram)][ngram]) > 0 - ): - pruned = False - - logging.debug( - "CONTEXT " - + str(h) - + " WORD " - + w - + " CONTEXTPROB %f " % h_log_p - + " OLDPROB %f " % log_p - + " NEWPROB %f " % (backoff_prob + new_log_bow) - + " DELTA-H %f " % delta_entropy - + " DELTA-LOGP %f " % delta_prob - + " PPL-CHANGE %f " % perp_change - + " PRUNED " - + str(pruned) - ) - - if pruned: - pruned_w_set.add(w) - count_pruned_ngrams += 1 - else: - all_pruned = False - - # If we removed all ngrams for this context we can - # remove the context itself, but only if the present - # context is not a prefix to a longer one. - if all_pruned and len(pruned_w_set) == len(h_dict[h]): - del h_dict[ - h - ] # this context h is no longer needed, as its ngram prob is stored at its own context h' - elif len(pruned_w_set) > 0: - # The pruning for this context h is actually done here - old_context = lm.set_new_context(h) - - for w, p_w in old_context.items(): - if w not in pruned_w_set: - lm.add_entry( - h + (w,), p_w - ) # the entry hw is stored at the context h - - # We need to recompute the back-off weight, but - # this can only be done after completing the pruning - # of the lower-order ngrams. - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 - - logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) - - # recompute backoff weights - for i in range( - max(minorder - 1, 1) + 1, lm.order() + 1 - ): # be careful of this order: from low- to high-order - for h in lm._ngrams[i - 1]: - numerator, denominator = compute_numerator_denominator(lm, h) - new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) - lm._ngrams[len(h)][h].log_bo = new_log_bow - - # update counts - lm.update_counts() - - return - - -def check_h_is_valid(lm, h): - sum_under_h = sum( - [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] - ) - if abs(sum_under_h - 1.0) > 1e-6: - logging.info("warning: %s %f" % (str(h), sum_under_h)) - return False - else: - return True - - -def validate_lm(lm): - # sanity check if the conditional probability sums to one under each context h - for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) - logging.info("validating %d-grams ..." 
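Plugging assumed toy numbers through the criterion above shows how the pruning decision is made (an editor-added sketch; every value is illustrative, and log_bow is derived from the numerator/denominator pair so that it stays consistent with the commented-out assertion earlier in prune()).

import math

base = 10.0
threshold = 1e-6                 # the script's default -threshold

h_log_p = -1.0                   # log10 P(h), marginal probability of the context
log_p = -0.7                     # log10 p(w | h), the entry considered for removal
backoff_prob = -1.2              # log10 p(w | h'), the lower-order estimate
numerator, denominator = 0.2, 0.35
log_bow = math.log(numerator, base) - math.log(denominator, base)  # current BOW(h)

# BOW'(h) after removing the entry, as in prune() above
new_log_bow = math.log(numerator + base**log_p, base) - math.log(
    denominator + base**backoff_prob, base
)

delta_prob = backoff_prob + new_log_bow - log_p
delta_entropy = -(base**h_log_p) * (
    (base**log_p) * delta_prob + numerator * (new_log_bow - log_bow)
)
perp_change = base**delta_entropy - 1.0
print(perp_change, perp_change < threshold)  # about 1.3e-2 here, so the entry is kept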
% i) - h_dict = lm._ngrams[i - 1] - for h in h_dict.keys(): - check_h_is_valid(lm, h) - - -def compare_two_apras(path1, path2): - pass - - -if __name__ == "__main__": - # load an arpa file - logging.info("Loading the arpa file from %s" % args.lm) - parser = ArpaParser() - models = parser.loadf(args.lm, encoding=default_encoding) - lm = models[0] # ARPA files may contain several models. - logging.info("Stats before pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - - # prune it, the language model will be modified in-place - logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) - prune(lm, args.threshold, args.minorder) - - # validate_lm(lm) - - # write the arpa language model to a file - logging.info("Stats after pruning:") - for i, cnt in lm.counts(): - logging.info("ngram %d=%d" % (i, cnt)) - logging.info("Saving the pruned arpa file to %s" % args.write_lm) - parser.dumpf(lm, args.write_lm, encoding=default_encoding) - logging.info("Done.") diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 120000 index 000000000..0e14ac415 --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 100755 index 71fb9e5ea..000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### Now we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file