From 748db76648947cafc3a4144d3476ad330d06e189 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Fri, 21 Jul 2023 00:50:01 +0800 Subject: [PATCH] updated --- egs/multi_zh-hans/ASR/prepare.sh | 17 +- egs/multi_zh-hans/ASR/zipformer/__init__.py | 0 .../ASR/zipformer/asr_datamodule.py | 156 +- .../ASR/zipformer/beam_search.py | 2908 ++++++++++++++++- egs/multi_zh-hans/ASR/zipformer/decode.py | 318 +- .../ASR/zipformer/decode_stream.py | 149 +- egs/multi_zh-hans/ASR/zipformer/decoder.py | 123 +- .../ASR/zipformer/encoder_interface.py | 44 +- .../ASR/zipformer/export-onnx-streaming.py | 777 ++++- .../ASR/zipformer/export-onnx.py | 622 +++- egs/multi_zh-hans/ASR/zipformer/export.py | 546 +++- .../ASR/zipformer/jit_pretrained.py | 281 +- .../ASR/zipformer/jit_pretrained_streaming.py | 274 +- egs/multi_zh-hans/ASR/zipformer/joiner.py | 67 +- egs/multi_zh-hans/ASR/zipformer/model.py | 359 +- .../ASR/zipformer/multi_dataset.py | 229 +- egs/multi_zh-hans/ASR/zipformer/onnx_check.py | 242 +- .../ASR/zipformer/onnx_decode.py | 105 +- .../zipformer/onnx_pretrained-streaming.py | 545 ++- .../ASR/zipformer/onnx_pretrained.py | 420 ++- egs/multi_zh-hans/ASR/zipformer/optim.py | 1174 ++++++- egs/multi_zh-hans/ASR/zipformer/pretrained.py | 382 ++- egs/multi_zh-hans/ASR/zipformer/scaling.py | 1841 ++++++++++- .../ASR/zipformer/scaling_converter.py | 105 +- .../ASR/zipformer/streaming_beam_search.py | 296 +- .../ASR/zipformer/streaming_decode.py | 145 +- .../ASR/zipformer/subsampling.py | 408 ++- egs/multi_zh-hans/ASR/zipformer/train.py | 282 +- egs/multi_zh-hans/ASR/zipformer/zipformer.py | 2255 ++++++++++++- 29 files changed, 14523 insertions(+), 547 deletions(-) delete mode 100644 egs/multi_zh-hans/ASR/zipformer/__init__.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/beam_search.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/decode_stream.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/decoder.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/encoder_interface.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/export-onnx.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/export.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/joiner.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/model.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/onnx_check.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/optim.py mode change 120000 => 100755 egs/multi_zh-hans/ASR/zipformer/pretrained.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/scaling.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/scaling_converter.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/subsampling.py mode change 120000 => 100644 egs/multi_zh-hans/ASR/zipformer/zipformer.py diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh index 2bbbc41c3..e5ac5d3d7 100755 --- a/egs/multi_zh-hans/ASR/prepare.sh +++ b/egs/multi_zh-hans/ASR/prepare.sh @@ -120,7 +120,9 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) . - ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train.jsonl.gz) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) . + ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . cd ../.. @@ -236,12 +238,11 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then cd data/fbank ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_raw.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_raw.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M_raw.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S_raw.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING_raw.jsonl.gz) . - ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET_raw.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . + ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . cd ../.. else log "Abort! Please run ../../wenetspeech/ASR/prepare.sh" @@ -324,4 +325,4 @@ if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then done ./local/train_bpe_model.py --lang-dir ./data/lang_bpe_${vocab_size} -fi \ No newline at end of file +fi diff --git a/egs/multi_zh-hans/ASR/zipformer/__init__.py b/egs/multi_zh-hans/ASR/zipformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 746b212ff..c47964b07 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -1,4 +1,5 @@ # Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,18 +21,11 @@ import inspect import logging from functools import lru_cache from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, @@ -40,7 +34,10 @@ from lhotse.dataset import ( SingleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader @@ -55,12 +52,13 @@ class _SeedWorkers: fix_random_seed(self.seed + worker_id) -class WenetSpeechAsrDataModule: +class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. It assumes there is always one train and valid dataloader, but there can be multiple test dataloaders (e.g. LibriSpeech test-clean and test-other). + It contains all the common data pipeline modules used in ASR experiments, e.g.: - dynamic batch size, @@ -68,6 +66,7 @@ class WenetSpeechAsrDataModule: - cut concatenation, - augmentation, - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. """ @@ -83,6 +82,20 @@ class WenetSpeechAsrDataModule: "effective batch sizes, sampling strategies, applied data " "augmentations, etc.", ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + group.add_argument( "--manifest-dir", type=Path, @@ -147,6 +160,12 @@ class WenetSpeechAsrDataModule: help="When enabled (=default), the examples will be " "shuffled for each epoch.", ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) group.add_argument( "--return-cuts", type=str2bool, @@ -190,10 +209,10 @@ class WenetSpeechAsrDataModule: ) group.add_argument( - "--training-subset", + "--input-strategy", type=str, - default="L", - help="The training subset for using", + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", ) def train_dataloaders( @@ -208,12 +227,11 @@ class WenetSpeechAsrDataModule: sampler_state_dict: The state dict for the training sampler. """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) @@ -262,6 +280,7 @@ class WenetSpeechAsrDataModule: logging.info("About to create train dataset") train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), cut_transforms=transforms, input_transforms=input_transforms, return_cuts=self.args.return_cuts, @@ -292,8 +311,7 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=300000, - drop_last=True, + drop_last=self.args.drop_last, ) else: logging.info("Using SingleCutSampler.") @@ -304,6 +322,10 @@ class WenetSpeechAsrDataModule: ) logging.info("About to create train dataloader") + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() @@ -318,10 +340,6 @@ class WenetSpeechAsrDataModule: worker_init_fn=worker_init_fn, ) - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: @@ -345,30 +363,28 @@ class WenetSpeechAsrDataModule: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, ) logging.info("About to create dev dataloader") - valid_dl = DataLoader( validate, - batch_size=None, sampler=valid_sampler, - num_workers=self.args.num_workers, + batch_size=None, + num_workers=2, persistent_workers=False, ) return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") + logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats - else PrecomputedFeatures(), + else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( @@ -376,7 +392,7 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=False, ) - + logging.debug("About to create test dataloader") test_dl = DataLoader( test, batch_size=None, @@ -386,24 +402,74 @@ class WenetSpeechAsrDataModule: return test_dl @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" + def train_clean_5_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get train-clean-5 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" ) - return cuts_train @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) @lru_cache() - def test_net_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_NET cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) @lru_cache() - def test_meeting_cuts(self) -> List[CutSet]: - logging.info("About to get TEST_MEETING cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_2_cuts(self) -> CutSet: + logging.info("mini_librispeech: About to get dev-clean-2 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/multi_zh-hans/ASR/zipformer/beam_search.py b/egs/multi_zh-hans/ASR/zipformer/beam_search.py deleted file mode 120000 index e24eca39f..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/beam_search.py b/egs/multi_zh-hans/ASR/zipformer/beam_search.py new file mode 100644 index 000000000..fd59d4b7f --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/beam_search.py @@ -0,0 +1,2907 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +from torch import nn + +from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost +from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer +from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM +from icefall.utils import ( + DecodingResults, + add_eos, + add_sos, + get_texts, + get_texts_with_timestamp, +) + + +def fast_beam_search_one_best( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + ilme_scale: float = 0.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ilme_scale=ilme_scale, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_LG( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + ilme_scale: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ilme_scale=ilme_scale, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + blank_penalty=blank_penalty, + temperature=temperature, + allow_partial=allow_partial, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search_nbest_oracle( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + ref_texts: List[List[int]], + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, + allow_partial: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we select `num_paths` linear paths from the lattice. The path + that has the minimum edit distance with the given reference transcript + is used as the output. + + This is the best result we can achieve for any nbest based rescoring + methods. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + ref_texts: + A list-of-list of integers containing the reference transcripts. + If the decoding_graph is a trivial_graph, the integer ID is the + BPE token ID. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + allow_partial=allow_partial, + blank_penalty=blank_penalty, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + hyps = nbest.build_levenshtein_graphs() + refs = k2.levenshtein_graph(ref_texts, device=hyps.device) + + levenshtein_alignment = k2.levenshtein_alignment( + refs=refs, + hyps=hyps, + hyp_to_ref_map=nbest.shape.row_ids(1), + sorted_match_ref=True, + ) + + tot_scores = levenshtein_alignment.get_tot_scores( + use_double_scores=False, log_semiring=False + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + + max_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def fast_beam_search( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, + allow_partial: bool = False, + blank_penalty: float = 0.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) + + if ilme_scale != 0: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + if blank_penalty != 0: + ilme_logits[:, 0] -= blank_penalty + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs + + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output( + encoder_out_lens.tolist(), allow_partial=allow_partial + ) + + return lattice + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + max_sym_per_frame: int, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + unk_id = getattr(model, "unk_id", blank_id) + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # timestamp[i] is the frame index after subsampling + # on which hyp[i] is decoded + timestamp = [] + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits is (1, 1, 1, vocab_size) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + y = logits.argmax().item() + if y not in (blank_id, unk_id): + hyp.append(y) + timestamp.append(t) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + if not return_timestamps: + return hyp + else: + return DecodingResults( + hyps=[hyp], + timestamps=[timestamp], + ) + + +def greedy_search_batch( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + blank_penalty: float = 0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] + + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + timestamps[i].append(t) + scores[i].append(logits[i, v].item()) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + ans_timestamps = [] + ans_scores = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) + ans_scores.append(scores[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + scores=ans_scores, + ) + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state + state_cost: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_graph: Optional[ContextGraph] = None, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=None if context_graph is None else context_graph.root, + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + context_state=new_context_state, + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) + + +def modified_beam_search_lm_rescore( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + # get the best hyp with different lm_scale + for lm_scale in lm_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}" + tot_scores = am_scores.values + lm_scores * lm_scale + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def modified_beam_search_lm_rescore_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def _deprecated_modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_timestamp = hyp.timestamp[:] + new_token = topk_token_indexes[i] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, + blank_penalty: float = 0.0, + return_timestamps: bool = False, +) -> Union[List[int], DecodingResults]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + device = next(model.parameters()).device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False, + ) + + if blank_penalty != 0: + logits[:, :, :, 0] -= blank_penalty + + # TODO(fangjun): Scale the blank posterior + log_prob = (logits / temperature).log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add( + Hypothesis( + ys=y_star.ys[:], + log_prob=new_y_star_log_prob, + timestamp=y_star.timestamp[:], + ) + ) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i in (blank_id, unk_id): + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + new_timestamp = y_star.timestamp + [t] + A.add( + Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ) + ) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + if not return_timestamps: + return ys + else: + return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) + + +def fast_beam_search_with_nbest_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model. The shortest path within the + lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def fast_beam_search_with_nbest_rnn_rescoring( + model: nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + rnn_lm_model: torch.nn.Module, + rnn_lm_scale_list: List[float], + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Dict[str, Union[List[List[int]], DecodingResults]]: + """It limits the maximum number of symbols per frame to 1. + A lattice is first obtained using fast beam search, num_path are selected + and rescored using a given language model and a rnn-lm. + The shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + rnn_lm_model: + A rnn-lm model used for LM rescoring + rnn_lm_scale_list: + A list of floats representing RNN score scales. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results + optionally with timestamps. `xx` is the ngram LM scale value + used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + + for words in word_list: + this_word_ids = [] + for w in words.split(): + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + # Now RNN-LM + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("sos_id") + eos_id = sp.piece_to_id("eos_id") + + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_list) + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ans: Dict[str, List[List[int]]] = {} + for n_scale in ngram_lm_scale_list: + for rnn_scale in rnn_lm_scale_list: + key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" + tot_scores = ( + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + if not return_timestamps: + ans[key] = get_texts(best_path) + else: + ans[key] = get_texts_with_timestamp(best_path) + + return ans + + +def modified_beam_search_ngram_rescoring( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ngram_lm: NgramLm, + ngram_lm_scale: float, + beam: int = 4, + temperature: float = 1.0, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + lm_scale = ngram_lm_scale + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state_cost=NgramLmStateCost(ngram_lm), + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale + for hyps in A + for hyp in hyps + ] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + else: + state_cost = hyp.state_cost + + # We only keep AM scores in new_hyp.log_prob + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, state_cost=state_cost + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_LODR( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + LM: LmScorer, + beam: int = 4, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, # state of the NN LM + lm_score=init_score.reshape(-1), + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + LM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + # forward NN LM to get new states and scores + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + ) # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_lm_shallow_fusion( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index 0fbc8244b..93680602e 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao -# Mingshuang Luo) +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,53 +20,59 @@ Usage: (1) greedy search ./zipformer/decode.py \ - --epoch 35 \ + --epoch 28 \ --avg 15 \ --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ --max-duration 600 \ --decoding-method greedy_search -(2) modified beam search +(2) beam search (not recommended) ./zipformer/decode.py \ - --epoch 35 \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ --avg 15 \ --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 -(3) fast beam search (trivial_graph) +(4) fast beam search (one best) ./zipformer/decode.py \ - --epoch 35 \ + --epoch 28 \ --avg 15 \ --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 -(4) fast beam search (LG) +(5) fast beam search (nbest) ./zipformer/decode.py \ - --epoch 30 \ + --epoch 28 \ --avg 15 \ --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ --max-duration 600 \ - --decoding-method fast_beam_search_LG \ + --decoding-method fast_beam_search_nbest \ --beam 20.0 \ --max-contexts 8 \ - --max-states 64 + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 -(5) fast beam search (nbest oracle WER) +(6) fast beam search (nbest oracle WER) ./zipformer/decode.py \ - --epoch 35 \ + --epoch 28 \ --avg 15 \ --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -75,6 +80,17 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -86,9 +102,10 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -99,10 +116,8 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from lhotse.cut import Cut -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_params, get_model -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -173,10 +188,17 @@ def get_parser(): help="The experiment dir", ) + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + parser.add_argument( "--lang-dir", type=Path, - default="data/lang_char", + default="data/lang_bpe_500", help="The lang dir containing word table and LG graph", ) @@ -186,11 +208,13 @@ def get_parser(): default="greedy_search", help="""Possible values are: - greedy_search + - beam_search - modified_beam_search - fast_beam_search - - fast_beam_search_LG + - fast_beam_search_nbest - fast_beam_search_nbest_oracle - If you use fast_beam_search_LG, you have to specify + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, ) @@ -212,7 +236,7 @@ def get_parser(): search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. Used only when --decoding-method is fast_beam_search, - fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle """, ) @@ -222,27 +246,17 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_LG. + Used only when --decoding_method is fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) - parser.add_argument( - "--ilme-scale", - type=float, - default=0.2, - help=""" - Used only when --decoding_method is fast_beam_search_LG. - It specifies the scale for the internal language model estimation. - """, - ) - parser.add_argument( "--max-contexts", type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -251,7 +265,7 @@ def get_parser(): type=int, default=64, help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search, fast_beam_search_LG, + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) @@ -259,9 +273,9 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) - parser.add_argument( "--max-sym-per-frame", type=int, @@ -275,7 +289,8 @@ def get_parser(): type=int, default=200, help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest_oracle""", + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -283,19 +298,8 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) add_model_arguments(parser) @@ -306,9 +310,9 @@ def get_parser(): def decode_one_batch( params: AttributeDict, model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -326,12 +330,16 @@ def decode_one_batch( It's the return value of :func:`get_params`. model: The neural model. + sp: + The BPE model. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -358,13 +366,7 @@ def decode_one_batch( value=LOG_EPS, ) - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) hyps = [] @@ -377,12 +379,11 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, - blank_penalty=params.blank_penalty, ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "fast_beam_search_LG": - hyp_tokens = fast_beam_search_one_best( + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -390,12 +391,25 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, - blank_penalty=params.blank_penalty, - ilme_scale=params.ilme_scale, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, ) for hyp in hyp_tokens: - sentence = "".join([lexicon.word_table[i] for i in hyp]) - hyps.append(list(sentence)) + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -406,85 +420,81 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, num_paths=params.num_paths, - ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, beam=params.beam_size, ) - for i in range(encoder_out.size(0)): - hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, - blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, - blank_penalty=params.blank_penalty, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append([lexicon.token_table[idx] for idx in hyp]) + hyps.append(sp.decode(hyp).split()) - key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - return {"greedy_search_" + key: hyps} + return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: - key += f"_beam_{params.beam}_" + key = f"beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" key += f"max_states_{params.max_states}" if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ilme_scale_{params.ilme_scale}" - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return {f"beam_size_{params.beam_size}_" + key: hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - lexicon: Lexicon, - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -494,8 +504,12 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. + sp: + The BPE model. + word_table: + The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -520,15 +534,14 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list("".join(text.split())) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( params=params, model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, + sp=sp, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -536,7 +549,8 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -545,14 +559,16 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -578,7 +594,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -596,7 +613,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -606,10 +623,11 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", - "modified_beam_search", "fast_beam_search", - "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", + "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -635,15 +653,15 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"_ilme_scale_{params.ilme_scale}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -657,14 +675,13 @@ def main(): logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() logging.info(params) @@ -673,9 +690,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -702,9 +719,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -752,8 +769,9 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( @@ -761,47 +779,37 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") # we need cut ids to display recognition results. args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() - dev_cuts = wenetspeech.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dls = [dev_dl, test_net_dl, test_meeting_dl] - - for test_set, test_dl in zip(test_sets, test_dls): + for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( dl=test_dl, params=params, model=model, - lexicon=lexicon, - graph_compiler=graph_compiler, + sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/multi_zh-hans/ASR/zipformer/decode_stream.py b/egs/multi_zh-hans/ASR/zipformer/decode_stream.py deleted file mode 120000 index b8d8ddfc4..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/decode_stream.py b/egs/multi_zh-hans/ASR/zipformer/decode_stream.py new file mode 100644 index 000000000..946db275c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/decode_stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/multi_zh-hans/ASR/zipformer/decoder.py b/egs/multi_zh-hans/ASR/zipformer/decoder.py deleted file mode 120000 index 5a8018680..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/decoder.py b/egs/multi_zh-hans/ASR/zipformer/decoder.py new file mode 100644 index 000000000..e8db988f6 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/decoder.py @@ -0,0 +1,122 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scaling import Balancer + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + decoder_dim: + Dimension of the input embedding, and of the decoder output. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + + self.embedding = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=decoder_dim, + ) + # the balancers are to avoid any drift in the magnitude of the + # embeddings, which would interact badly with parameter averaging. + self.balancer = Balancer(decoder_dim, channel_dim=-1, + min_positive=0.0, max_positive=1.0, + min_abs=0.5, max_abs=1.0, + prob=0.05) + + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + self.vocab_size = vocab_size + + if context_size > 1: + self.conv = nn.Conv1d( + in_channels=decoder_dim, + out_channels=decoder_dim, + kernel_size=context_size, + padding=0, + groups=decoder_dim // 4, # group size == 4 + bias=False, + ) + self.balancer2 = Balancer(decoder_dim, channel_dim=-1, + min_positive=0.0, max_positive=1.0, + min_abs=0.5, max_abs=1.0, + prob=0.05) + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + y = y.to(torch.int64) + # this stuff about clamp() is a temporary fix for a mismatch + # at utterance start, we use negative ids in beam_search.py + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + + embedding_out = self.balancer(embedding_out) + + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = F.relu(embedding_out) + embedding_out = self.balancer2(embedding_out) + + return embedding_out diff --git a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py deleted file mode 120000 index 653c5b09a..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py new file mode 100644 index 000000000..257facce4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/encoder_interface.py @@ -0,0 +1,43 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + + +class EncoderInterface(nn.Module): + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (batch_size, input_seq_len, num_features) + containing the input features. + x_lens: + A tensor of shape (batch_size,) containing the number of frames + in `x` before padding. + Returns: + Return a tuple containing two tensors: + - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) + containing unnormalized probabilities, i.e., the output of a + linear layer. + - encoder_out_lens, a tensor of shape (batch_size,) containing + the number of frames in `encoder_out` before padding. + """ + raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py deleted file mode 120000 index 2962eb784..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py new file mode 100755 index 000000000..3eb06f68c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1,776 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +The --chunk-size in training is "16,32,64,-1", so we select one of them +(excluding -1) during streaming export. The same applies to `--left-context`, +whose value is "64,128,256,-1". + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1-chunk-16-left-64.onnx + - decoder-epoch-99-avg-1-chunk-16-left-64.onnx + - joiner-epoch-99-avg-1-chunk-16-left-64.onnx + +See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from export import num_tokens +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + self.pad_length = 7 + 2 * 3 + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + N = x.size(0) + T = self.chunk_size * 2 + self.pad_length + x_lens = torch.tensor([T] * N, device=x.device) + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=x, + x_lens=x_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) + + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) + encoder_states = states[:-2] + logging.info(f"len_encoder_states={len(encoder_states)}") + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + + return encoder_out, new_states + + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) + states.append(processed_lens) + + return states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.chunk_size * 2 + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + T = decode_chunk_len + encoder_model.pad_length + + x = torch.rand(1, T, 80, dtype=torch.float32) + init_state = encoder_model.get_init_states() + num_encoders = len(encoder_model.encoder.encoder_dim) + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel)) + ds = encoder_model.encoder.downsampling_factor + left_context_len = encoder_model.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim)) + value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim)) + num_heads = ",".join(map(str, encoder_model.encoder.num_heads)) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "streaming zipformer2", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 32+7+2*3=45 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + "query_head_dims": query_head_dims, + "value_head_dims": value_head_dims, + "num_heads": num_heads, + } + logging.info(f"meta_data: {meta_data}") + + for i in range(len(init_state[:-2]) // 6): + build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + embed_states = init_state[-2] + name = "embed_states" + logging.info(f"{name}.shape: {embed_states.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size,) + processed_lens = init_state[-1] + name = "processed_lens" + logging.info(f"{name}.shape: {processed_lens.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py deleted file mode 120000 index 70a15683c..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py new file mode 100755 index 000000000..724fdd2a6 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx.py @@ -0,0 +1,621 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from export import num_tokens +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/export.py b/egs/multi_zh-hans/ASR/zipformer/export.py deleted file mode 120000 index dfc1bec08..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/export.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export.py b/egs/multi_zh-hans/ASR/zipformer/export.py new file mode 100755 index 000000000..4a48d5bad --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export.py @@ -0,0 +1,545 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py deleted file mode 120000 index 25108391f..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py new file mode 100755 index 000000000..a41fbc1c9 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer/exp/cpu_jit.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = model.decoder.blank_id + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 120000 index 1962351e9..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py new file mode 100755 index 000000000..d4ceacefd --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained_streaming.py \ + --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List, Optional + +import k2 +import kaldifeat +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model jit_script.pt", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, + device: torch.device = torch.device("cpu"), +): + assert encoder_out.ndim == 2 + context_size = decoder.context_size + blank_id = decoder.blank_id + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor( + decoder_input, dtype=torch.int32, device=device + ).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + model.eval() + model.to(device) + + encoder = model.encoder + decoder = model.decoder + joiner = model.joiner + + token_table = k2.SymbolTable.from_file(args.tokens) + context_size = decoder.context_size + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + + chunk_length = encoder.chunk_size * 2 + T = chunk_length + encoder.pad_length + + logging.info(f"chunk_length: {chunk_length}") + logging.info(f"T: {T}") + + states = encoder.get_init_states(device=device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32, device=device) + encoder_out, out_lens, states = encoder( + features=frames, + feature_lengths=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device + ) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/joiner.py b/egs/multi_zh-hans/ASR/zipformer/joiner.py deleted file mode 120000 index 5b8a36332..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/joiner.py b/egs/multi_zh-hans/ASR/zipformer/joiner.py new file mode 100644 index 000000000..f03cc930e --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/joiner.py @@ -0,0 +1,66 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from scaling import ScaledLinear + + +class Joiner(nn.Module): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + super().__init__() + + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.output_linear = nn.Linear(joiner_dim, vocab_size) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) + else: + logit = encoder_out + decoder_out + + logit = self.output_linear(torch.tanh(logit)) + + return logit diff --git a/egs/multi_zh-hans/ASR/zipformer/model.py b/egs/multi_zh-hans/ASR/zipformer/model.py deleted file mode 120000 index cd7e07d72..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/model.py b/egs/multi_zh-hans/ASR/zipformer/model.py new file mode 100644 index 000000000..b541ee697 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/model.py @@ -0,0 +1,358 @@ +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask +from scaling import ScaledLinear + + +class AsrModel(nn.Module): + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: Optional[nn.Module] = None, + joiner: Optional[nn.Module] = None, + encoder_dim: int = 384, + decoder_dim: int = 512, + vocab_size: int = 500, + use_transducer: bool = True, + use_ctc: bool = False, + ): + """A joint CTC & Transducer ASR model. + + - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf) + - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf) + - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf) + + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dim) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + It is used when use_transducer is True. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + It is used when use_transducer is True. + use_transducer: + Whether use transducer head. Default: True. + use_ctc: + Whether use CTC head. Default: False. + """ + super().__init__() + + assert ( + use_transducer or use_ctc + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder_embed = encoder_embed + self.encoder = encoder + + self.use_transducer = use_transducer + if use_transducer: + # Modules for Transducer head + assert decoder is not None + assert hasattr(decoder, "blank_id") + assert joiner is not None + + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_scale=0.25 + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, vocab_size, initial_scale=0.25 + ) + else: + assert decoder is None + assert joiner is None + + self.use_ctc = use_ctc + if use_ctc: + # Modules for CTC head + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward_encoder( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute encoder outputs. + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + + Returns: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + """ + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) + + return encoder_out, encoder_out_lens + + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + if self.use_ctc: + # Compute CTC loss + targets = y.values + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss diff --git a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py index 856130ee4..d957c6cc0 100644 --- a/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh-hans/ASR/zipformer/multi_dataset.py @@ -37,92 +37,193 @@ class MultiDataset: - aishell4_cuts_train_M.jsonl.gz - aishell4_cuts_train_S.jsonl.gz - alimeeting-far_cuts_train.jsonl.gz - - cuts_L.jsonl.gz - - cuts_M.jsonl.gz - - cuts_S.jsonl.gz + - magicdata_cuts_train.jsonl.gz - primewords_cuts_train.jsonl.gz - stcmds_cuts_train.jsonl.gz - thchs_30_cuts_train.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz + - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz + - wenetspeech/cuts_L.jsonl.gz + - wenetspeech/cuts_M.jsonl.gz + - wenetspeech/cuts_S.jsonl.gz """ self.fbank_dir = Path(fbank_dir) def train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") - # LibriSpeech - logging.info("Loading LibriSpeech in lazy mode") - librispeech_cuts = load_manifest_lazy( - self.fbank_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" - ) + # THCHS-30 + logging.info("Loading THCHS-30 in lazy mode") + thchs_30_cuts = load_manifest_lazy( + self.fbank_dir / "thchs_30_cuts_train.jsonl.gz" + ) - # GigaSpeech - filenames = glob.glob(f"{self.fbank_dir}/XL_split/cuts_XL.*.jsonl.gz") + # AISHELL-1 + logging.info("Loading Aishell-1 in lazy mode") + aishell_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_train.jsonl.gz" + ) - pattern = re.compile(r"cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + # AISHELL-2 + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) - sorted_filenames = [f[1] for f in idx_filenames] + # AISHELL-4 + logging.info("Loading Aishell-4 in lazy mode") + aishell_4_L_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz" + ) + aishell_4_M_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz" + ) + aishell_4_S_cuts = load_manifest_lazy( + self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz" + ) - logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode") + # ST-CMDS + logging.info("Loading ST-CMDS in lazy mode") + stcmds_cuts = load_manifest_lazy( + self.fbank_dir / "stcmds_cuts_train.jsonl.gz" + ) - gigaspeech_cuts = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + # Primewords + logging.info("Loading Primewords in lazy mode") + primewords_cuts = load_manifest_lazy( + self.fbank_dir / "primewords_cuts_train.jsonl.gz" + ) - # CommonVoice - logging.info("Loading CommonVoice in lazy mode") - commonvoice_cuts = load_manifest_lazy( - self.fbank_dir / f"cv-en_cuts_train.jsonl.gz" - ) + # MagicData + logging.info("Loading MagicData in lazy mode") + magicdata_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_train.jsonl.gz" + ) - # LibriHeavy - logging.info("Loading LibriHeavy in lazy mode") - libriheavy_small_cuts = load_manifest_lazy( - self.fbank_dir / "libriheavy_cuts_train_small.jsonl.gz" - ) - libriheavy_medium_cuts = load_manifest_lazy( - self.fbank_dir / "libriheavy_cuts_train_medium.jsonl.gz" - ) - libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts) + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh in lazy mode") + aidatatang_200zh_cuts = load_manifest_lazy( + self.fbank_dir / "aidatatang_cuts_train.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting in lazy mode") + alimeeting_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech in lazy mode") + wenetspeech_L_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz" + ) + wenetspeech_M_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_M.jsonl.gz" + ) + wenetspeech_S_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_S.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech in lazy mode") + kespeech_1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz" + ) + kespeech_2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz" + ) return CutSet.mux( - librispeech_cuts, - gigaspeech_cuts, - commonvoice_cuts, - libriheavy_cuts, - weights=[ - len(librispeech_cuts), - len(gigaspeech_cuts), - len(commonvoice_cuts), - len(libriheavy_cuts), - ], - ) + thchs_30_cuts, + aishell_cuts, + aishell_2_cuts, + aishell_4_L_cuts, + aishell_4_M_cuts, + aishell_4_S_cuts, + stcmds_cuts, + primewords_cuts, + magicdata_cuts, + aidatatang_200zh_cuts, + alimeeting_cuts, + wenetspeech_L_cuts, + wenetspeech_M_cuts, + wenetspeech_S_cuts, + kespeech_1_cuts, + kespeech_2_cuts, + weights=[ + len(thchs_30_cuts), + len(aishell_cuts), + len(aishell_2_cuts), + len(aishell_4_L_cuts), + len(aishell_4_M_cuts), + len(aishell_4_S_cuts), + len(stcmds_cuts), + len(primewords_cuts), + len(magicdata_cuts), + len(aidatatang_200zh_cuts), + len(alimeeting_cuts), + len(wenetspeech_L_cuts), + len(wenetspeech_M_cuts), + len(wenetspeech_S_cuts), + len(kespeech_1_cuts), + len(kespeech_2_cuts), + ], + ) - def test_cuts(self) -> CutSet: - logging.info("About to get multidataset test cuts") + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") - # GigaSpeech - logging.info("Loading GigaSpeech DEV in lazy mode") - gigaspeech_dev_cuts = load_manifest_lazy(self.fbank_dir / "cuts_DEV.jsonl.gz") + # Aidatatang_200zh + logging.info("Loading Aidatatang_200zh DEV set in lazy mode") + aidatatang_dev_cuts = load_manifest_lazy(self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz") - logging.info("Loading GigaSpeech TEST in lazy mode") - gigaspeech_test_cuts = load_manifest_lazy(self.fbank_dir / "cuts_TEST.jsonl.gz") + # AISHELL + logging.info("Loading Aishell DEV set in lazy mode") + aishell_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell_cuts_dev.jsonl.gz" + ) - # CommonVoice - logging.info("Loading CommonVoice DEV in lazy mode") - commonvoice_dev_cuts = load_manifest_lazy( - self.fbank_dir / "cv-en_cuts_dev.jsonl.gz" - ) + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # Ali-Meeting + logging.info("Loading Ali-Meeting DEV set in lazy mode") + alimeeting_dev_cuts = load_manifest_lazy( + self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz" + ) + + # MagicData + logging.info("Loading MagicData DEV set in lazy mode") + magicdata_dev_cuts = load_manifest_lazy( + self.fbank_dir / "magicdata_cuts_dev.jsonl.gz" + ) + + # KeSpeech + logging.info("Loading KeSpeech DEV set in lazy mode") + kespeech_dev_phase1_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz" + ) + kespeech_dev_phase2_cuts = load_manifest_lazy( + self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz" + ) + + # WeNetSpeech + logging.info("Loading WeNetSpeech DEV set in lazy mode") + wenetspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz" + ) - logging.info("Loading CommonVoice TEST in lazy mode") - commonvoice_test_cuts = load_manifest_lazy( - self.fbank_dir / "cv-en_cuts_test.jsonl.gz" - ) return [ - gigaspeech_dev_cuts, - gigaspeech_test_cuts, - commonvoice_dev_cuts, - commonvoice_test_cuts, - ] + aidatatang_dev_cuts, + aishell_dev_cuts, + aishell2_dev_cuts, + alimeeting_dev_cuts, + magicdata_dev_cuts, + kespeech_dev_phase1_cuts, + kespeech_dev_phase2_cuts, + wenetspeech_dev_cuts, + ] diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py deleted file mode 120000 index f3dd42004..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_check.py b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py new file mode 100755 index 000000000..b38b875d0 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_check.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./zipformer/export.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - jit_script.pt + +3. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. Run this file + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +import torch +from onnx_pretrained import OnnxModel + +from icefall import is_module_available + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py index ed5f6db08..2aca36ca9 100755 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_decode.py @@ -2,8 +2,7 @@ # # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao, -# Xiaoyu Yang, -# Wei Kang) +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,47 +21,50 @@ This script loads ONNX exported models and uses them to decode the test sets. We use the pre-trained model from -https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 as an example to show how to use this file. 1. Download the pre-trained model -cd egs/wenetspeech/ASR +cd egs/librispeech/ASR -repo_url=https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_char/tokens.txt" +git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp -ln -s pretrained.pt epoch-9999.pt +ln -s pretrained.pt epoch-99.pt popd 2. Export the model to ONNX ./zipformer/export-onnx.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --epoch 9999 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ --avg 1 \ - --exp-dir $repo/exp/ + --exp-dir $repo/exp \ + --causal False It will generate the following 3 files inside $repo/exp: - - encoder-epoch-9999-avg-1.onnx - - decoder-epoch-9999-avg-1.onnx - - joiner-epoch-9999-avg-1.onnx + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx 2. Run this file ./zipformer/onnx_decode.py \ - --exp-dir ./zipformer/exp \ + --exp-dir $repo/exp \ --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ """ @@ -72,14 +74,14 @@ import time from pathlib import Path from typing import List, Tuple -import k2 import torch import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule -from lhotse.cut import Cut -from onnx_pretrained import OnnxModel, greedy_search +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel from icefall.utils import setup_logger, store_transcripts, write_error_stats +from k2 import SymbolTable def get_parser(): @@ -111,15 +113,14 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless7/exp", + default="zipformer/exp", help="The experiment dir", ) parser.add_argument( "--tokens", type=str, - default="data/lang_char/tokens.txt", - help="Path to the tokens.txt", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -133,7 +134,7 @@ def get_parser(): def decode_one_batch( - model: OnnxModel, token_table: k2.SymbolTable, batch: dict + model: OnnxModel, token_table: SymbolTable, batch: dict ) -> List[List[str]]: """Decode one batch and return the result. Currently it only greedy_search is supported. @@ -142,7 +143,7 @@ def decode_one_batch( model: The neural model. token_table: - Mapping ids to tokens. + The token table. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -164,14 +165,20 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens ) - hyps = [[token_table[h] for h in hyp] for hyp in hyps] + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] return hyps def decode_dataset( dl: torch.utils.data.DataLoader, model: nn.Module, - token_table: k2.SymbolTable, + token_table: SymbolTable, ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: """Decode dataset. @@ -181,7 +188,7 @@ def decode_dataset( model: The neural model. token_table: - Mapping ids to tokens. + The token table. Returns: - A list of tuples. Each tuple contains three elements: @@ -211,7 +218,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = list(ref_text) + ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) results.extend(this_batch) @@ -256,7 +263,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() assert ( @@ -270,8 +277,7 @@ def main(): device = torch.device("cpu") logging.info(f"Device: {device}") - token_table = k2.SymbolTable.from_file(args.tokens) - assert token_table[0] == "" + token_table = SymbolTable.from_file(args.tokens) logging.info(vars(args)) @@ -284,37 +290,20 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) - wenetspeech = WenetSpeechAsrDataModule(args) + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() - def remove_short_utt(c: Cut): - T = ((c.num_frames - 7) // 2 + 1) // 2 - if T <= 0: - logging.warning( - f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." - ) - return T > 0 + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) - dev_cuts = wenetspeech.valid_cuts() - dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dl = [dev_dl, test_net_dl, test_meeting_dl] + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) + results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 120000 index cfea104c2..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py new file mode 100755 index 000000000..2ce4506a8 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + logging.info(f"encoder_meta={encoder_meta}") + + model_type = encoder_meta["model_type"] + assert model_type == "zipformer2", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + query_head_dims = encoder_meta["query_head_dims"] + value_head_dims = encoder_meta["value_head_dims"] + num_heads = encoder_meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + encoder_input[name] = tensors[0] + encoder_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + encoder_input[name] = tensors[1] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + encoder_input[name] = tensors[2] + encoder_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + encoder_input[name] = tensors[3] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + encoder_input[name] = tensors[4] + encoder_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + encoder_input[name] = tensors[5] + encoder_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + encoder_input[name] = embed_states + encoder_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + encoder_input[name] = processed_lens + encoder_output.append(f"new_{name}") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + token_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += token_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py deleted file mode 120000 index 8f32f4ee7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py new file mode 100755 index 000000000..e8a521460 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, joiner_dim) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.run_decoder(decoder_input) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) + + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = model.run_decoder(decoder_input) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + + token_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/optim.py b/egs/multi_zh-hans/ASR/zipformer/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/optim.py b/egs/multi_zh-hans/ASR/zipformer/optim.py new file mode 100644 index 000000000..abfb2092c --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/optim.py @@ -0,0 +1,1173 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [ x[0] for x in cur_group["named_params"] ] + p_list = [ x[1] for x in cur_group["named_params"] ] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, + (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py deleted file mode 120000 index 0bd71dde4..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..3104b6084 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bpe_500/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling.py b/egs/multi_zh-hans/ASR/zipformer/scaling.py deleted file mode 120000 index 6f398f431..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling.py b/egs/multi_zh-hans/ASR/zipformer/scaling.py new file mode 100644 index 000000000..4ee7b7826 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/scaling.py @@ -0,0 +1,1840 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Tuple, Union +import logging +import k2 +from torch.cuda.amp import custom_fwd, custom_bwd +import random +import torch +import math +import torch.nn as nn +from torch import Tensor + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [ (float(x), float(y)) for x,y in args ] + for (x,y) in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear( + * [(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear( + * [(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear( (0, x) ) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear( (0, x) ) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, + p: 'PiecewiseLinear', + include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise lienar + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) + y_vals1 = [ self(x) for x in x_vals ] + y_vals2 = [ p(x) for x in x_vals ] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [ self(x) for x in x_vals ] + y_vals2 = [ p(x) for x in x_vals ] + return ( PiecewiseLinear(* zip(x_vals, y_vals1)), + PiecewiseLinear(* zip(x_vals, y_vals2)) ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specifiy the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + def __init__(self, + *args, + default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + + def __float__(self): + batch_count = self.batch_count + if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, + default=self.default) + else: + return ScheduledFloat(self.schedule + x.schedule, + default=self.default+x.default) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), + default=self.default) + else: + return ScheduledFloat(self.schedule.max(x.schedule), + default=max(self.default, x.default)) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = (x_abs < min_abs) + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = (x > self.cutoff) + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1-q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + ans, = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return (x - bias) * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, + store_output_for_backprop: bool) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + ans = x * scales + ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, + scales.detach(), bias.detach(), log_scale.detach()) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * + self.log_scale.exp()) + return x * scales + + log_scale = limit_param_value(self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training) + + return BiasNormFunction.apply(x, self.bias, log_scale, + self.channel_dim, + self.store_output_for_backprop) + + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, + initial_scale: float = 1.0, + **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + def __init__(self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d(in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True) + + self.chunkwise_conv = nn.Conv1d(in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_(self.causal_conv.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) + + def forward(self, + x: Tensor, + chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + # half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, + num_channels, chunk_size) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape(batch_size, num_chunks, + num_channels, chunk_size).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros(channels, t, + device=left_edge.device, + dtype=left_edge.dtype) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Streaming Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + cache: cached left context of shape (batch_size, channels, left_pad) + """ + (batch_size, num_channels, seq_len) = x.shape + + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + # Pad cache + assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -left_pad:] + + x_causal = self.causal_conv(x) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size=seq_len) + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal, cache + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None]: + x, = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] + uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = (m_loss + r_loss) + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if (torch.jit.is_scripting() or not x.requires_grad or + (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, + name: str = None) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, ::dim+1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, + num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, + x: Tensor, + module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float,float]], + grad_scale: FloatLike): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, + x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device), None + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + x, = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value(x: Tensor, + min: float, max: float, + prob: float = 0.6, + training: bool = True): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = (y * (1 - s) + s) + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.044 + ceil = 1.2 + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = (d * ((ceil - floor) / 255.0) + floor) + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, + p=float(self.p), + training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + ans, = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + + y.backward(gradient = torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation. + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.cuda.amp.autocast(enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient = torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + d, = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation. + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int]): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = ((1.0 / (1.0 - dropout_p)) * + (torch.rand(*dropout_shape, + device=x.device, dtype=x.dtype) > dropout_p)) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + 'SwooshL': k2.swoosh_l_forward, + 'SwooshR': k2.swoosh_r_forward + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + 'SwooshL': k2.swoosh_l_forward_and_deriv, + 'SwooshR': k2.swoosh_r_forward_and_deriv + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), + y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + def __init__(self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = 'SwooshL', + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear(in_channels, out_channels, + bias=bias, + initial_scale=initial_scale) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter('bias', l.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, + x: Tensor): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == 'SwooshL': + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, + self.weight, + self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, self.weight, self.bias, self.activation, + float(self.dropout_p), self.dropout_shared_dim) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = ((1.2-(-0.043637))/255.0) + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = (1.0 / 255.0) + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = (1.0 / 255.0) + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:,0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:,0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear( (0, 10.0) ) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) + for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ['SwooshL', 'SwooshR']: + m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=0.5)) + m2 = ActivationDropoutAndLinear(in_channels, out_channels, + bias=bias, initial_scale=0.5, + activation=activation, + dropout_p=dropout_p) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose(SwooshRFunction.apply(x1), + SwooshRForward(x1), + atol=1.0e-03) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, + atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, + atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_double_swish_deriv() + _test_swooshr_deriv() + _test_swooshl_deriv() + _test_activation_dropout_and_linear() diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py deleted file mode 120000 index b0ecee05e..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py new file mode 100644 index 000000000..76622fa12 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/scaling_converter.py @@ -0,0 +1,104 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file replaces various modules in a model. +Specifically, ActivationBalancer is replaced with an identity operator; +Whiten is also replaced with an identity operator; +BasicNorm is replaced by a module with `exp` removed. +""" + +import copy +from typing import List, Tuple + +import torch +import torch.nn as nn +from scaling import ( + Balancer, + Dropout3, + ScaleGrad, + SwooshL, + SwooshLOnnx, + SwooshR, + SwooshROnnx, + Whiten, +) +from zipformer import CompactRelPositionalEncoding + + +# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa +# get_submodule was added to nn.Module at v1.9.0 +def get_submodule(model, target): + if target == "": + return model + atoms: List[str] = target.split(".") + mod: torch.nn.Module = model + for item in atoms: + if not hasattr(mod, item): + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) + mod = getattr(mod, item) + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " "an nn.Module") + return mod + + +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_pnnx: bool = False, + is_onnx: bool = False, +): + """ + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. + is_onnx: + True if we are going to export the model for ONNX. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): + d[name] = nn.Identity() + elif is_onnx and isinstance(m, SwooshR): + d[name] = SwooshROnnx() + elif is_onnx and isinstance(m, SwooshL): + d[name] = SwooshLOnnx() + elif is_onnx and isinstance(m, CompactRelPositionalEncoding): + # We want to recreate the positional encoding vector when + # the input changes, so we have to use torch.jit.script() + # to replace torch.jit.trace() + d[name] = torch.jit.script(m) + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(get_submodule(model, parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py deleted file mode 120000 index b1ed54557..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py new file mode 100644 index 000000000..3c8565b33 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1,295 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import k2 +import torch +import torch.nn as nn +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from decode_stream import DecodeStream + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + blank_penalty: float = 0.0, +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + num_active_paths: int = 4, + blank_penalty: float = 0.0, +) -> None: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + num_active_paths: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + streams: List[DecodeStream], + beam: float, + max_states: int, + max_contexts: int, + blank_penalty: float = 0.0, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first generated by Fsa-based beam search, then we get the + recognition by applying shortest path on the lattice. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + streams: + A list of stream objects. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + B, T, C = encoder_out.shape + assert B == len(streams) + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + best_path = one_best_decoding(lattice) + hyp_tokens = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyp_tokens[i] diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py index 94c5fae5f..44ff392a3 100755 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py @@ -23,7 +23,7 @@ Usage: --epoch 28 \ --avg 15 \ --causal 1 \ - --chunk-size 16 \ + --chunk-size 32 \ --left-context-frames 256 \ --exp-dir ./zipformer/exp \ --decoding-method greedy_search \ @@ -38,8 +38,9 @@ from typing import Dict, List, Optional, Tuple import k2 import numpy as np +import sentencepiece as spm import torch -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -50,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_params, get_model from icefall.checkpoint import ( average_checkpoints, @@ -58,7 +59,6 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, make_pad_mask, @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--bpe-model", type=str, - default="data/lang_char", - help="Path to the lang dir(containing lexicon, tokens, etc.)", + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", ) parser.add_argument( @@ -181,18 +181,6 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) - parser.add_argument( - "--blank-penalty", - type=float, - default=0.0, - help=""" - The penalty applied on blank symbol during decoding. - Note: It is a positive value that would be applied to logits like - this `logits[:, 0] -= blank_penalty` (suppose logits.shape is - [batch_size, vocab] and blank id is 0). - """, - ) - parser.add_argument( "--num-decode-streams", type=int, @@ -294,7 +282,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + processed_lens = torch.cat( + [state_list[i][-1] for i in range(batch_size)], dim=0 + ) batch_states.append(processed_lens) return batch_states @@ -332,7 +322,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + cached_key_list = batch_states[layer_offset].chunk( + chunks=batch_size, dim=1 + ) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -363,7 +355,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + cached_embed_left_pad_list = batch_states[-2].chunk( + chunks=batch_size, dim=0 + ) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -386,7 +380,11 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -406,7 +404,9 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -495,10 +495,7 @@ def decode_one_chunk( if params.decoding_method == "greedy_search": greedy_search( - model=model, - encoder_out=encoder_out, - streams=decode_streams, - blank_penalty=params.blank_penalty, + model=model, encoder_out=encoder_out, streams=decode_streams ) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) @@ -511,7 +508,6 @@ def decode_one_chunk( beam=params.beam, max_states=params.max_states, max_contexts=params.max_contexts, - blank_penalty=params.blank_penalty, ) elif params.decoding_method == "modified_beam_search": modified_beam_search( @@ -519,10 +515,11 @@ def decode_one_chunk( streams=decode_streams, encoder_out=encoder_out, num_active_paths=params.num_active_paths, - blank_penalty=params.blank_penalty, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = unstack_states(new_states) @@ -540,7 +537,7 @@ def decode_dataset( cuts: CutSet, params: AttributeDict, model: nn.Module, - lexicon: Lexicon, + sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -552,8 +549,8 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - lexicon: - The Lexicon. + sp: + The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. @@ -580,7 +577,9 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = get_init_states(model=model, batch_size=1, device=device) + initial_states = get_init_states( + model=model, batch_size=1, device=device + ) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -596,12 +595,7 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - if audio.max() > 1: - logging.warning( - f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." - f"Clipping to [-1, 1]." - ) - audio = np.clip(audio, -1, 1) + assert audio.max() <= 1, "Should be normalized to [-1, 1])" samples = torch.from_numpy(audio).squeeze(0) @@ -620,11 +614,8 @@ def decode_dataset( decode_results.append( ( decode_streams[i].id, - list(decode_streams[i].ground_truth.strip()), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -642,27 +633,25 @@ def decode_dataset( ( decode_streams[i].id, decode_streams[i].ground_truth.split(), - [ - lexicon.token_table[idx] - for idx in decode_streams[i].decoding_result() - ], + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] - key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - key = f"greedy_search_{key}" + key = "greedy_search" elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_{key}" + f"max_states_{params.max_states}" ) elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}_{key}" + key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -695,7 +684,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -713,7 +703,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -728,13 +718,14 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"-left-context-{params.left_context_frames}" - params.suffix += f"-blank-penalty-{params.blank_penalty}" # for fast_beam_search if params.decoding_method == "fast_beam_search": @@ -754,9 +745,13 @@ def main(): logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() logging.info(params) @@ -765,9 +760,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -794,9 +789,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -851,23 +846,23 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - wenetspeech = WenetSpeechAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) - dev_cuts = wenetspeech.valid_cuts() - test_net_cuts = wenetspeech.test_net_cuts() - test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( cuts=test_cut, params=params, model=model, - lexicon=lexicon, + sp=sp, decoding_graph=decoding_graph, ) + save_results( params=params, test_set_name=test_set, diff --git a/egs/multi_zh-hans/ASR/zipformer/subsampling.py b/egs/multi_zh-hans/ASR/zipformer/subsampling.py deleted file mode 120000 index 01ae9002c..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/subsampling.py b/egs/multi_zh-hans/ASR/zipformer/subsampling.py new file mode 100644 index 000000000..d6bf57db4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/subsampling.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple +import warnings + +import torch +from torch import Tensor, nn +from scaling import ( + Balancer, + BiasNorm, + Dropout3, + FloatLike, + Optional, + ScaledConv2d, + ScaleGrad, + ScheduledFloat, + SwooshL, + SwooshR, + Whiten, +) + + +class ConvNeXt(nn.Module): + """ + Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + """ + + def __init__( + self, + channels: int, + hidden_ratio: int = 3, + kernel_size: Tuple[int, int] = (7, 7), + layerdrop_rate: FloatLike = None, + ): + super().__init__() + self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + hidden_channels = channels * hidden_ratio + if layerdrop_rate is None: + layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) + self.layerdrop_rate = layerdrop_rate + + self.depthwise_conv = nn.Conv2d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=self.padding, + ) + + self.pointwise_conv1 = nn.Conv2d( + in_channels=channels, out_channels=hidden_channels, kernel_size=1 + ) + + self.hidden_balancer = Balancer( + hidden_channels, + channel_dim=1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + self.activation = SwooshL() + self.pointwise_conv2 = ScaledConv2d( + in_channels=hidden_channels, + out_channels=channels, + kernel_size=1, + initial_scale=0.01, + ) + + self.out_balancer = Balancer( + channels, + channel_dim=1, + min_positive=0.4, + max_positive=0.6, + min_abs=1.0, + max_abs=6.0, + ) + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.forward_internal(x) + layerdrop_rate = float(self.layerdrop_rate) + + if layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = ( + torch.rand( + (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device + ) + > layerdrop_rate + ) + else: + mask = None + # turns out this caching idea does not work with --world-size > 1 + # return caching_eval(self.forward_internal, x, mask) + return self.forward_internal(x, mask) + + def forward_internal( + self, x: Tensor, layer_skip_mask: Optional[Tensor] = None + ) -> Tensor: + """ + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + + The returned value has the same shape as x. + """ + bypass = x + x = self.depthwise_conv(x) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + if layer_skip_mask is not None: + x = x * layer_skip_mask + + x = bypass + x + x = self.out_balancer(x) + x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last + x = self.out_whiten(x) + x = x.transpose(1, 3) # (N, C, H, W) + + return x + + def streaming_forward( + self, + x: Tensor, + cached_left_pad: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) + cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) + + Returns: + - The returned value has the same shape as x. + - Updated cached_left_pad. + """ + padding = self.padding + + # The length without right padding for depth-wise conv + T = x.size(2) - padding[0] + + bypass = x[:, :, :T, :] + + # Pad left side + assert cached_left_pad.size(2) == padding[0], ( + cached_left_pad.size(2), + padding[0], + ) + x = torch.cat([cached_left_pad, x], dim=2) + # Update cached left padding + cached_left_pad = x[:, :, T : padding[0] + T, :] + + # depthwise_conv + x = torch.nn.functional.conv2d( + x, + weight=self.depthwise_conv.weight, + bias=self.depthwise_conv.bias, + padding=(0, padding[1]), + groups=self.depthwise_conv.groups, + ) + x = self.pointwise_conv1(x) + x = self.hidden_balancer(x) + x = self.activation(x) + x = self.pointwise_conv2(x) + + x = bypass + x + return x, cached_left_pad + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: FloatLike = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-3)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + bottleneck: + bottleneck dimension for 1d squeeze-excite + """ + assert in_channels >= 7 + super().__init__() + + # The ScaleGrad module is there to prevent the gradients + # w.r.t. the weight or bias of the first Conv2d module in self.conv from + # exceeding the range of fp16 when using automatic mixed precision (amp) + # training. (The second one is necessary to stop its bias from getting + # a too-large gradient). + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ScaleGrad(0.2), + Balancer(layer1_channels, channel_dim=1, max_abs=1.0), + SwooshR(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + Balancer(layer2_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + Balancer(layer3_channels, channel_dim=1, max_abs=4.0), + SwooshR(), + ) + + # just one convnext layer + self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) + + self.out_width = (((in_channels - 1) // 2) - 1) // 2 + self.layer3_channels = layer3_channels + + self.out = nn.Linear(self.out_width * layer3_channels, out_channels) + # use a larger than normal grad_scale on this whitening module; there is + # only one such module, so there is not a concern about adding together + # many copies of this extra gradient term. + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=ScheduledFloat( + (0.0, 4.0), (20000.0, 8.0), default=4.0 + ), + prob=(0.025, 0.25), + grad_scale=0.02, + ) + + # max_log_eps=0.0 is to prevent both eps and the output of self.out from + # getting large, there is an unnecessary degree of freedom. + self.out_norm = BiasNorm(out_channels) + self.dropout = Dropout3(dropout, shared_dim=1) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - output lengths, of shape (batch_size,) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) + # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite + # gradients. + x = self.conv(x) + x = self.convnext(x) + + # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_whiten(x) + x = self.out_norm(x) + x = self.dropout(x) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + x_lens = (x_lens - 7) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = (x_lens - 7) // 2 + assert x.size(1) == x_lens.max().item() + + return x, x_lens + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + cached_left_pad: Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + + Returns: + - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + - output lengths, of shape (batch_size,) + - updated cache + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + + # T' = (T-7)//2 + x = self.conv(x) + + # T' = (T-7)//2-3 + x, cached_left_pad = self.convnext.streaming_forward( + x, cached_left_pad=cached_left_pad + ) + + # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + + x = x.transpose(1, 2).reshape(b, t, c * f) + # now x: (N, T', out_width * layer3_channels)) + + x = self.out(x) + # Now x is of shape (N, T', odim) + x = self.out_norm(x) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert self.convnext.padding[0] == 3 + # The ConvNeXt module needs 3 frames of right padding after subsampling + x_lens = (x_lens - 7) // 2 - 3 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # The ConvNeXt module needs 3 frames of right padding after subsampling + assert self.convnext.padding[0] == 3 + x_lens = (x_lens - 7) // 2 - 3 + + assert x.size(1) == x_lens.max().item() + + return x, x_lens, cached_left_pad + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + """Get initial states for Conv2dSubsampling module. + It is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + """ + left_pad = self.convnext.padding[0] + freq = self.out_width + channels = self.layer3_channels + cached_embed_left_pad = torch.zeros( + batch_size, channels, left_pad, freq + ).to(device) + + return cached_embed_left_pad diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 83dbfa22f..bc3e9c1ba 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -21,29 +21,33 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: ./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ - --start-epoch 1 \ - --exp-dir zipformer/exp \ - --training-subset L - --lr-epochs 1.5 \ - --max-duration 350 - -# For mix precision training: - -./zipformer/train.py \ - --world-size 8 \ - --num-epochs 12 \ + --world-size 4 \ + --num-epochs 30 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ - --training-subset L \ - --lr-epochs 1.5 \ - --max-duration 750 + --full-libri 1 \ + --max-duration 1000 +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` """ @@ -57,10 +61,11 @@ from typing import Any, Dict, Optional, Tuple, Union import k2 import optim +import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -77,7 +82,6 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -87,7 +91,6 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, @@ -139,42 +142,42 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--feedforward-dim", type=str, default="512,768,1024,1536,1024,768", - help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", ) parser.add_argument( "--num-heads", type=str, default="4,4,4,8,4,4", - help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", ) parser.add_argument( "--encoder-dim", type=str, default="192,256,384,512,384,256", - help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( "--query-head-dim", type=str, default="32", - help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( "--value-head-dim", type=str, default="12", - help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( "--pos-head-dim", type=str, default="4", - help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", ) parser.add_argument( @@ -188,14 +191,16 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dim", type=str, default="192,192,256,256,256,192", - help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", ) parser.add_argument( "--cnn-module-kernel", type=str, default="31,31,15,15,15,31", - help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", ) parser.add_argument( @@ -226,16 +231,31 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--chunk-size", type=str, default="16,32,64,-1", - help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", ) parser.add_argument( "--left-context-frames", type=str, default="64,128,256,-1", - help="""Maximum left-contexts for causal training, measured in frames which will - be converted to a number of chunks. If splitting into chunks, - chunk left-context frames will be chosen randomly from this list; else not relevant.""", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", ) @@ -302,13 +322,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--bpe-model", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", ) parser.add_argument( @@ -335,47 +352,55 @@ def get_parser(): "--ref-duration", type=float, default=600, - help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", ) parser.add_argument( "--context-size", type=int, default=2, - help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="""The prune range for rnnt loss, it means how many symbols(context) - we are using to compute the loss""", + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="""The scale to smooth the loss with lm - (output of prediction network) part.""", + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="""The scale to smooth the loss with am (output of encoder network) part.""", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="""To get pruning ranges, we will calculate a simple version - loss(joiner is just addition), this simple loss also uses for - training (as a regularization item). We will scale the simple loss - with this parameter before adding to the final loss.""", + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", ) parser.add_argument( @@ -408,7 +433,7 @@ def get_parser(): params.batch_idx_train % save_every_n == 0. The checkpoint filename has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. + end of each epoch where `xxx` is the epoch number counting from 1. """, ) @@ -502,7 +527,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, + "valid_interval": 3000, # For the 100h subset, use 800 # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. @@ -579,19 +604,32 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: + assert ( + params.use_transducer or params.use_ctc + ), (f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}") + encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, ) return model @@ -659,9 +697,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -718,12 +753,12 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute loss given the model and its inputs. Args: params: @@ -753,11 +788,11 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts) - y = k2.RaggedTensor(y).to(device) + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + simple_loss, pruned_loss, ctc_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -766,21 +801,27 @@ def compute_loss( lm_scale=params.lm_scale, ) - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += ( + simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss assert loss.requires_grad == is_training @@ -791,8 +832,11 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info @@ -800,7 +844,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -813,7 +857,7 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, + sp=sp, batch=batch, is_training=False, ) @@ -836,7 +880,7 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -880,8 +924,6 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - saved_bad_model = False def save_bad_model(suffix: str = ""): @@ -900,9 +942,6 @@ def train_one_epoch( for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -912,7 +951,7 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, + sp=sp, batch=batch, is_training=True, ) @@ -929,7 +968,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params, sp=sp) raise if params.print_diagnostics and batch_idx == 5: @@ -950,7 +989,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -963,7 +1001,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -1020,7 +1057,7 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, + sp=sp, valid_dl=valid_dl, world_size=world_size, ) @@ -1073,14 +1110,15 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 logging.info(params) @@ -1135,23 +1173,25 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - wenetspeech = WenetSpeechAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) - train_cuts = wenetspeech.train_cuts() - valid_cuts = wenetspeech.valid_cuts() + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 15 seconds + # Keep only utterances with duration between 1 second and 20 seconds # - # Caution: There is a reason to select 15.0 here. Please see + # Caution: There is a reason to select 20.0 here. Please see # ../local/display_manifest_statistics.py # # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 15.0: + if c.duration < 1.0 or c.duration > 20.0: # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) return False @@ -1162,7 +1202,7 @@ def run(rank, world_size, args): # In ./zipformer.py, the conv module uses the following expression # for subsampling T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0] + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): logging.warning( @@ -1186,18 +1226,20 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = wenetspeech.train_dataloaders( + train_dl = librispeech.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_dl = wenetspeech.valid_dataloaders(valid_cuts) + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) - if False and not params.print_diagnostics: + if not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, + sp=sp, params=params, ) @@ -1222,7 +1264,7 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, + sp=sp, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1256,7 +1298,7 @@ def run(rank, world_size, args): def display_and_save_batch( batch: dict, params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1266,8 +1308,8 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. + sp: + The BPE model. """ from lhotse.utils import uuid4 @@ -1280,8 +1322,7 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") - texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) + y = sp.encode(supervisions["text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") @@ -1290,7 +1331,7 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, + sp: spm.SentencePieceProcessor, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1306,7 +1347,7 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, + sp=sp, batch=batch, is_training=True, ) @@ -1321,7 +1362,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params, sp=sp) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -1330,9 +1371,8 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) world_size = args.world_size diff --git a/egs/multi_zh-hans/ASR/zipformer/zipformer.py b/egs/multi_zh-hans/ASR/zipformer/zipformer.py deleted file mode 120000 index 23011dda7..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/zipformer.py b/egs/multi_zh-hans/ASR/zipformer/zipformer.py new file mode 100644 index 000000000..7d98dbeb1 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/zipformer.py @@ -0,0 +1,2254 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling import ( + Balancer, + BiasNorm, + Dropout2, + ChunkCausalDepthwiseConv1d, + ActivationDropoutAndLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + penalize_abs_values_gt, + softmax, + ScheduledFloat, + FloatLike, + limit_param_value, + convert_num_channels, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), + (20000.0, 0.1)) + + def _to_tuple(x): + """ Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u,d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample(max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout) + + def get_feature_masks( + self, + x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [ 1.0 ] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0 + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = (torch.rand(1, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and(mask1, + (torch.rand(1, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype)) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones(1, batch_size, channels, + dtype=x.dtype, device=x.device) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module(x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=(None if src_key_padding_mask is None + else src_key_padding_mask[...,::ds]), + attn_mask=attn_mask) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, + chunk_size: int, + left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all (chunk_size * left_context_chunks >= + (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders)) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, + src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [ outputs[-1] ] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) + cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) + states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), + (20000.0, ratio * x), + default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), + ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, + straight_through_rate=0) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, pos_dim=pos_dim, num_heads=num_heads, + query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, + value_head_dim) + + self.feed_forward1 = FeedforwardModule(embed_dim, + (feedforward_dim * 3) // 4, + dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, + feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, + (feedforward_dim * 5) // 4, + dropout) + + self.nonlin_attention = NonlinAttention(embed_dim, + hidden_channels=3 * embed_dim // 4) + + self.conv_module1 = ConvolutionModule(embed_dim, + cnn_module_kernel, + causal=causal) + + self.conv_module2 = ConvolutionModule(embed_dim, + cnn_module_kernel, + causal=causal) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.2, max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.balancer2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.1, max_abs=4.0, + ) + + def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: + if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) + selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), + ff2_skip_rate) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), + ff3_skip_rate) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, + length_factor=1.0) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1. / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 + ) = states[i * 6: (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2 + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value(self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max)) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, + src_orig: Tensor, + src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + def __init__(self, + encoder: nn.Module, + dim: int, + downsample: int, + dropout: FloatLike): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, + downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds,::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[:src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + def __init__(self, + channels: int, + downsample: int, + dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + def __init__(self, + num_channels: int, + upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, + src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + def __init__( + self, embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T-1), T, + device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = (self.embed_dim ** 0.5) + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + : + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.0)) + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=query_head_dim**-0.25) + + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer(key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(pos_dim, + num_heads * pos_head_dim, + bias=False, + initial_scale=0.05) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt(attn_scores, + limit=25.0, + penalty=1.0e-04, + name=self.name) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy( + self, + attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).mean(dim=(1,2)) + logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, + num_heads * value_head_dim, + bias=True) + + self.out_proj = ScaledLinear(num_heads * value_head_dim, + embed_dim, bias=True, + initial_scale=0.05) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model. + """ + def __init__(self, + embed_dim: int, + feedforward_dim: int, + dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer(feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, + activation='SwooshL', + dropout_p=dropout, + dropout_shared_dim=0, bias=True, + initial_scale=0.1) + + self.out_whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear(hidden_channels, channels, + bias=True, + initial_scale=0.05) + + self.whiten1 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.whiten2 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + def __init__( + self, channels: int, kernel_size: int, causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ChunkCausalDepthwiseConv1d( + channels=bottleneck_dim, + kernel_size=kernel_size) if causal else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2) + + self.balancer2 = Balancer( + bottleneck_dim, channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, channels, activation='SwooshR', + dropout_p=0.0, initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=-1) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0: + # Not support exporting a model for simulated streaming decoding + assert self.causal, "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). + cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,) + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True)