From 947ae0a73cf623ae5bbf902d1f51752068d52136 Mon Sep 17 00:00:00 2001 From: AmirHussein96 Date: Sat, 13 Sep 2025 09:57:15 -0400 Subject: [PATCH] replace files with symbolic links --- egs/iwslt22_ta/ASR/local/cer.py | 59 +- egs/iwslt22_ta/ASR/local/compile_hlg.py | 159 - egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py | 4 +- .../ASR/local/compute_fbank_musan.py | 110 +- .../convert_transcript_words_to_tokens.py | 107 - egs/iwslt22_ta/ASR/local/cuts_validate.py | 110 +- .../ASR/local/display_manifest_statistics.py | 98 +- egs/iwslt22_ta/ASR/local/download_lm.py | 97 - egs/iwslt22_ta/ASR/local/filter_cuts.py | 1 + .../ASR/local/generate_unique_lexicon.py | 101 +- egs/iwslt22_ta/ASR/local/prep_lexicon.sh | 19 +- egs/iwslt22_ta/ASR/local/prepare_lang.py | 415 +-- egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py | 256 +- egs/iwslt22_ta/ASR/local/prepare_lexicon.py | 39 - .../ASR/local/prepare_transcripts.py | 56 +- egs/iwslt22_ta/ASR/local/test_prepare_lang.py | 107 +- egs/iwslt22_ta/ASR/local/train_bpe_model.py | 99 +- egs/iwslt22_ta/ASR/local/validate_manifest.py | 1 + .../asr_finetune_datamodule.py | 422 --- .../beam_search.py | 2086 +----------- .../beam_search_old.py | 977 ------ .../decode_stream.py | 147 +- .../pruned_transducer_stateless5/decoder.py | 104 +- .../encoder_interface.py | 44 +- .../pruned_transducer_stateless5/export.py | 244 +- .../pruned_transducer_stateless5/joiner.py | 75 +- .../ASR/pruned_transducer_stateless5/model.py | 208 +- .../ASR/pruned_transducer_stateless5/optim.py | 332 +- .../pruned_transducer_stateless5/scaling.py | 720 +--- .../scaling_converter.py | 315 +- .../streaming_beam_search.py | 283 +- .../streaming_decode.py | 609 +--- .../ASR/pruned_transducer_stateless5/train.py | 23 +- egs/iwslt22_ta/ASR/shared | 1 + egs/iwslt22_ta/ASR/zipformer/beam_search.py | 3016 +---------------- egs/iwslt22_ta/ASR/zipformer/decoder.py | 124 +- egs/iwslt22_ta/ASR/zipformer/export.py | 524 +-- .../ASR/zipformer/generate_averaged_model.py | 203 +- .../ASR/zipformer/jit_pretrained.py | 273 +- .../ASR/zipformer/jit_pretrained_streaming.py | 270 +- egs/iwslt22_ta/ASR/zipformer/joiner.py | 67 +- egs/iwslt22_ta/ASR/zipformer/model.py | 490 +-- egs/iwslt22_ta/ASR/zipformer/optim.py | 1174 +------ egs/iwslt22_ta/ASR/zipformer/pretrained.py | 383 +-- egs/iwslt22_ta/ASR/zipformer/scaling.py | 1798 +--------- .../ASR/zipformer/scaling_converter.py | 83 +- .../ASR/zipformer/streaming_beam_search.py | 283 +- .../ASR/zipformer/streaming_decode.py | 877 +---- egs/iwslt22_ta/ASR/zipformer/subsampling.py | 408 +-- egs/iwslt22_ta/ASR/zipformer/zipformer.py | 2238 +----------- egs/iwslt22_ta/ST/README.md | 8 +- egs/iwslt22_ta/ST/RESULTS.md | 54 +- egs/iwslt22_ta/ST/local/compile_hlg.py | 159 - egs/iwslt22_ta/ST/local/compute_fbank_gpu.py | 4 +- .../ST/local/compute_fbank_musan.py | 110 +- .../ST/local/display_manifest_statistics.py | 98 +- egs/iwslt22_ta/ST/local/download_lm.py | 97 - .../ST/local/generate_unique_lexicon.py | 101 +- egs/iwslt22_ta/ST/local/prepare_lang.py | 415 +-- egs/iwslt22_ta/ST/local/prepare_lang_bpe.py | 256 +- .../ST/local/prepare_transcripts.py | 3 +- egs/iwslt22_ta/ST/local/test_prepare_lang.py | 107 +- egs/iwslt22_ta/ST/local/train_bpe_model.py | 99 +- .../asr_finetune_datamodule.py | 422 --- .../beam_search.py | 2086 +----------- .../beam_search_old.py | 977 ------ .../dataloader_invest.py | 93 - .../{decode_st.py => decode.py} | 0 .../decode_asr.py | 960 ------ .../decode_stream.py | 147 +- .../pruned_transducer_stateless5/decoder.py | 104 +- .../encoder_interface.py | 44 +- .../ST/pruned_transducer_stateless5/export.py | 244 +- .../ST/pruned_transducer_stateless5/joiner.py | 75 +- .../ST/pruned_transducer_stateless5/model.py | 208 +- .../ST/pruned_transducer_stateless5/optim.py | 332 +- .../pruned_transducer_stateless5/scaling.py | 720 +--- .../scaling_converter.py | 315 +- .../streaming_beam_search.py | 283 +- .../streaming_decode.py | 609 +--- .../{train_st.py => train.py} | 21 +- .../train_analysis.py | 1202 ------- .../pruned_transducer_stateless5/train_asr.py | 1301 ------- egs/iwslt22_ta/ST/shared | 1 + egs/iwslt22_ta/ST/zipformer/asr_datamodule.py | 22 +- .../ST/zipformer/{decode_st.py => decode.py} | 94 +- egs/iwslt22_ta/ST/zipformer/decode_asr.py | 852 ----- egs/iwslt22_ta/ST/zipformer/decoder.py | 30 +- .../ST/zipformer/encoder_interface.py | 44 +- egs/iwslt22_ta/ST/zipformer/joiner.py | 67 +- egs/iwslt22_ta/ST/zipformer/model.py | 4 +- egs/iwslt22_ta/ST/zipformer/optim.py | 1174 +------ egs/iwslt22_ta/ST/zipformer/profile.py | 177 +- egs/iwslt22_ta/ST/zipformer/scaling.py | 1798 +--------- .../ST/zipformer/scaling_converter.py | 83 +- egs/iwslt22_ta/ST/zipformer/subsampling.py | 408 +-- egs/iwslt22_ta/ST/zipformer/train.py | 234 +- egs/iwslt22_ta/ST/zipformer/train_asr.py | 1417 -------- egs/iwslt22_ta/ST/zipformer/train_st.py | 1422 -------- egs/iwslt22_ta/ST/zipformer/zipformer.py | 2238 +----------- .../ASR/local/compute_fbank_musan.py | 0 egs/librispeech/ASR/zipformer/train.py | 3 +- 102 files changed, 344 insertions(+), 42017 deletions(-) mode change 100644 => 120000 egs/iwslt22_ta/ASR/local/cer.py delete mode 100755 egs/iwslt22_ta/ASR/local/compile_hlg.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/compute_fbank_musan.py delete mode 100755 egs/iwslt22_ta/ASR/local/convert_transcript_words_to_tokens.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/cuts_validate.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/display_manifest_statistics.py delete mode 100755 egs/iwslt22_ta/ASR/local/download_lm.py create mode 120000 egs/iwslt22_ta/ASR/local/filter_cuts.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/prep_lexicon.sh mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/prepare_lang.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py delete mode 100755 egs/iwslt22_ta/ASR/local/prepare_lexicon.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/prepare_transcripts.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/test_prepare_lang.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/local/train_bpe_model.py create mode 120000 egs/iwslt22_ta/ASR/local/validate_manifest.py delete mode 100644 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_finetune_datamodule.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py delete mode 100644 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search_old.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py create mode 120000 egs/iwslt22_ta/ASR/shared mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/beam_search.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/decoder.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/export.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/joiner.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/model.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/optim.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/pretrained.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/scaling.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/scaling_converter.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py mode change 100755 => 120000 egs/iwslt22_ta/ASR/zipformer/streaming_decode.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/subsampling.py mode change 100644 => 120000 egs/iwslt22_ta/ASR/zipformer/zipformer.py delete mode 100755 egs/iwslt22_ta/ST/local/compile_hlg.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/compute_fbank_musan.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/display_manifest_statistics.py delete mode 100755 egs/iwslt22_ta/ST/local/download_lm.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/generate_unique_lexicon.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/prepare_lang.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/prepare_lang_bpe.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/test_prepare_lang.py mode change 100755 => 120000 egs/iwslt22_ta/ST/local/train_bpe_model.py delete mode 100644 egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_finetune_datamodule.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py delete mode 100644 egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search_old.py delete mode 100644 egs/iwslt22_ta/ST/pruned_transducer_stateless5/dataloader_invest.py rename egs/iwslt22_ta/ST/pruned_transducer_stateless5/{decode_st.py => decode.py} (100%) delete mode 100755 egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_asr.py mode change 100755 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py mode change 100755 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py mode change 100644 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py mode change 100755 => 120000 egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py rename egs/iwslt22_ta/ST/pruned_transducer_stateless5/{train_st.py => train.py} (98%) delete mode 100755 egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_analysis.py delete mode 100755 egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_asr.py create mode 120000 egs/iwslt22_ta/ST/shared rename egs/iwslt22_ta/ST/zipformer/{decode_st.py => decode.py} (94%) delete mode 100755 egs/iwslt22_ta/ST/zipformer/decode_asr.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/encoder_interface.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/joiner.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/optim.py mode change 100755 => 120000 egs/iwslt22_ta/ST/zipformer/profile.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/scaling.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/scaling_converter.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/subsampling.py delete mode 100755 egs/iwslt22_ta/ST/zipformer/train_asr.py delete mode 100755 egs/iwslt22_ta/ST/zipformer/train_st.py mode change 100644 => 120000 egs/iwslt22_ta/ST/zipformer/zipformer.py mode change 100755 => 100644 egs/librispeech/ASR/local/compute_fbank_musan.py diff --git a/egs/iwslt22_ta/ASR/local/cer.py b/egs/iwslt22_ta/ASR/local/cer.py deleted file mode 100644 index 3635d2e22..000000000 --- a/egs/iwslt22_ta/ASR/local/cer.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/python -# Copyright 2023 Johns Hopkins University (Amir Hussein) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -""" -This script computes CER for the decodings generated by icefall recipe -""" - -import argparse -import jiwer -import os - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dec-file", - type=str, - help="file with decoded text" - ) - - return parser - -def cer_(file): - hyp = [] - ref = [] - cer_results = 0 - ref_lens = 0 - with open(file, 'r', encoding='utf-8') as dec: - - for line in dec: - id, target = line.split('\t') - id = id[0:-2] - target, txt = target.split("=") - if target == 'ref': - words = txt.strip().strip('[]').split(', ') - word_list = [word.strip("'") for word in words] - ref.append(" ".join(word_list)) - elif target == 'hyp': - words = txt.strip().strip('[]').split(', ') - word_list = [word.strip("'") for word in words] - hyp.append(" ".join(word_list)) - for h, r in zip(hyp, ref): - #breakpoint() - cer_results += (jiwer.cer(r, h)*len(r)) - ref_lens += len(r) - print(os.path.basename(file)) - print(cer_results/ref_lens) - - - - -def main(): - parse = get_args() - args = parse.parse_args() - cer_(args.dec_file) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/cer.py b/egs/iwslt22_ta/ASR/local/cer.py new file mode 120000 index 000000000..ab9f2ef24 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/cer.py @@ -0,0 +1 @@ +../../ST/local/cer.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/compile_hlg.py b/egs/iwslt22_ta/ASR/local/compile_hlg.py deleted file mode 100755 index 9a35750e0..000000000 --- a/egs/iwslt22_ta/ASR/local/compile_hlg.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_3_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py b/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py index 84e17addb..05ed0a74a 100755 --- a/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py +++ b/egs/iwslt22_ta/ASR/local/compute_fbank_gpu.py @@ -45,8 +45,6 @@ from lhotse.features.kaldifeat import ( # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect # even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) def get_args(): parser = argparse.ArgumentParser() @@ -91,7 +89,7 @@ def compute_fbank_gpu(args): "dev", ) manifests = read_manifests_if_cached( - prefix="iwslt", dataset_parts=dataset_parts, output_dir=src_dir + prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir ) assert manifests is not None diff --git a/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py b/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py deleted file mode 100755 index 48905de6f..000000000 --- a/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def is_cut_long(c: MonoCut) -> bool: - return c.duration > 5 - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(30, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = output_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(is_cut_long) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/musan_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py b/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/convert_transcript_words_to_tokens.py b/egs/iwslt22_ta/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 100755 index 133499c8b..000000000 --- a/egs/iwslt22_ta/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -""" -Convert a transcript file containing words to a corpus file containing tokens -for LM training with the help of a lexicon. - -If the lexicon contains phones, the resulting LM will be a phone LM; If the -lexicon contains word pieces, the resulting LM will be a word piece LM. - -If a word has multiple pronunciations, the one that appears first in the lexicon -is kept; others are removed. - -If the input transcript is: - - hello zoo world hello - world zoo - foo zoo world hellO - -and if the lexicon is - - SPN - hello h e l l o 2 - hello h e l l o - world w o r l d - zoo z o o - -Then the output is - - h e l l o 2 z o o w o r l d h e l l o 2 - w o r l d z o o - SPN z o o w o r l d SPN -""" - -import argparse -from pathlib import Path -from typing import Dict, List - -from generate_unique_lexicon import filter_multiple_pronunications - -from icefall.lexicon import read_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--transcript", - type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", - ) - parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) - - return parser.parse_args() - - -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: - """ - Args: - lexicon: - A dict containing pronunciations. Its keys are words and values - are pronunciations (i.e., tokens). - line: - A line of transcript consisting of space(s) separated words. - oov_token: - The pronunciation of the oov word if a word in `line` is not present - in the lexicon. - Returns: - Return None. - """ - s = "" - words = line.strip().split() - for i, w in enumerate(words): - tokens = lexicon.get(w, oov_token) - s += " ".join(tokens) - s += " " - print(s.strip()) - - -def main(): - args = get_args() - assert Path(args.lexicon).is_file() - assert Path(args.transcript).is_file() - assert len(args.oov) > 0 - - # Only the first pronunciation of a word is kept - lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) - - lexicon = dict(lexicon) - - assert args.oov in lexicon - - oov_token = lexicon[args.oov] - - with open(args.transcript) as f: - for line in f: - process_line(lexicon=lexicon, line=line, oov_token=oov_token) - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/cuts_validate.py b/egs/iwslt22_ta/ASR/local/cuts_validate.py deleted file mode 100755 index 879242b00..000000000 --- a/egs/iwslt22_ta/ASR/local/cuts_validate.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/python -# Copyright 2023 Johns Hopkins University (Amir Hussein) -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -""" -This script helps validating the prepared manifests (recordings, supervisions) -and CutSets - -""" -from lhotse import RecordingSet, SupervisionSet, CutSet -import argparse -import logging -from lhotse.qa import fix_manifests, validate_recordings_and_supervisions -import pdb - - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--sup", - type=str, - default="", - help="Supervisions file", - ) - - parser.add_argument( - "--rec", - type=str, - default="", - help="Recordings file", - ) - parser.add_argument( - "--cut", - type=str, - default="", - help="Cutset file", - ) - parser.add_argument( - "--savecut", - type=str, - default="", - help="name of the cutset to be saved", - ) - - - - return parser - -def valid_asr(cut): - tol = 2e-3 - i=0 - total_dur = 0 - for c in cut: - if c.supervisions != []: - if c.supervisions[0].end > c.duration + tol: - - logging.info(f"Supervision beyond the cut. Cut number: {i}") - total_dur += c.duration - logging.info(f"id: {c.id}, sup_end: {c.supervisions[0].end}, dur: {c.duration}, source {c.recording.sources[0].source}") - elif c.supervisions[0].start < -tol: - logging.info(f"Supervision starts before the cut. Cut number: {i}") - logging.info(f"id: {c.id}, sup_start: {c.supervisions[0].start}, dur: {c.duration}, source {c.recording.sources[0].source}") - else: - continue - else: - logging.info("Empty supervision") - logging.info(f"id: {c.id}") - i += 1 - logging.info(f"filtered duration: {total_dur}") - - -def main(): - - parser = get_parser() - args = parser.parse_args() - if args.cut != "": - cuts = CutSet.from_file(args.cut) - else: - - recordings = RecordingSet.from_file(args.rec) - supervisions = SupervisionSet.from_file(args.sup) - logging.info("Example from supervisions:") - logging.info(supervisions[0]) - logging.info("Example from recordings") - print(recordings[0]) - logging.info("Fixing manifests") - recordings, supervisions = fix_manifests(recordings, supervisions) - - logging.info("Validating manifests") - validate_recordings_and_supervisions(recordings, supervisions) - - cuts = CutSet.from_manifests(recordings= recordings, supervisions=supervisions,) - - cuts = cuts.trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) - logging.info("Example from cut:") - print(cuts[100]) - breakpoint() - cuts.describe() - logging.info("Validating manifests for ASR") - valid_asr(cuts) - if args.savecut != "": - cuts.to_file(args.savecut) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/cuts_validate.py b/egs/iwslt22_ta/ASR/local/cuts_validate.py new file mode 120000 index 000000000..98bfc9c72 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/cuts_validate.py @@ -0,0 +1 @@ +../../ST/local/cuts_validate.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py b/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py deleted file mode 100755 index d3e224905..000000000 --- a/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. -""" - - -from lhotse import load_manifest - - -def main(): - # path = "./data/fbank/cuts_train.jsonl.gz" - path = "./data/fbank/cuts_dev.jsonl.gz" - # path = "./data/fbank/cuts_test.jsonl.gz" - - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -# train - -Cuts count: 1125309 -Total duration (hours): 3403.9 -Speech duration (hours): 3403.9 (100.0%) -*** -Duration statistics (seconds): -mean 10.9 -std 10.1 -min 0.2 -25% 5.2 -50% 7.8 -75% 12.7 -99% 52.0 -99.5% 65.1 -99.9% 99.5 -max 228.9 - - -# test -Cuts count: 5365 -Total duration (hours): 9.6 -Speech duration (hours): 9.6 (100.0%) -*** -Duration statistics (seconds): -mean 6.4 -std 1.5 -min 1.6 -25% 5.3 -50% 6.5 -75% 7.6 -99% 9.5 -99.5% 9.7 -99.9% 10.3 -max 12.4 - -# dev -Cuts count: 5002 -Total duration (hours): 8.5 -Speech duration (hours): 8.5 (100.0%) -*** -Duration statistics (seconds): -mean 6.1 -std 1.7 -min 1.5 -25% 4.8 -50% 6.2 -75% 7.4 -99% 9.5 -99.5% 9.7 -99.9% 10.1 -max 20.3 - -""" diff --git a/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py b/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py new file mode 120000 index 000000000..e99e43515 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/display_manifest_statistics.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/display_manifest_statistics.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/download_lm.py b/egs/iwslt22_ta/ASR/local/download_lm.py deleted file mode 100755 index 94d23afed..000000000 --- a/egs/iwslt22_ta/ASR/local/download_lm.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file downloads the following LibriSpeech LM files: - - - 3-gram.pruned.1e-7.arpa.gz - - 4-gram.arpa.gz - - librispeech-vocab.txt - - librispeech-lexicon.txt - -from http://www.openslr.org/resources/11 -and save them in the user provided directory. - -Files are not re-downloaded if they already exist. - -Usage: - ./local/download_lm.py --out-dir ./download/lm -""" - -import argparse -import gzip -import logging -import os -import shutil -from pathlib import Path - -from lhotse.utils import urlretrieve_progress -from tqdm.auto import tqdm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--out-dir", type=str, help="Output directory.") - - args = parser.parse_args() - return args - - -def main(out_dir: str): - url = "http://www.openslr.org/resources/11" - out_dir = Path(out_dir) - - files_to_download = ( - "3-gram.pruned.1e-7.arpa.gz", - "4-gram.arpa.gz", - "librispeech-vocab.txt", - "librispeech-lexicon.txt", - ) - - for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): - filename = out_dir / f - if filename.is_file() is False: - urlretrieve_progress( - f"{url}/{f}", - filename=filename, - desc=f"Downloading {filename}", - ) - else: - logging.info(f"{filename} already exists - skipping") - - if ".gz" in str(filename): - unzipped = Path(os.path.splitext(filename)[0]) - if unzipped.is_file() is False: - with gzip.open(filename, "rb") as f_in: - with open(unzipped, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - else: - logging.info(f"{unzipped} already exist - skipping") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - logging.info(f"out_dir: {args.out_dir}") - - main(out_dir=args.out_dir) diff --git a/egs/iwslt22_ta/ASR/local/filter_cuts.py b/egs/iwslt22_ta/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py b/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py deleted file mode 100755 index 566c0743d..000000000 --- a/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input a lexicon.txt and output a new lexicon, -in which each word has a unique pronunciation. - -The way to do this is to keep only the first pronunciation of a word -in lexicon.txt. -""" - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -from icefall.lexicon import read_lexicon, write_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - This file will generate a new file uniq_lexicon.txt - in it. - """, - ) - - return parser.parse_args() - - -def filter_multiple_pronunications( - lexicon: List[Tuple[str, List[str]]] -) -> List[Tuple[str, List[str]]]: - """Remove multiple pronunciations of words from a lexicon. - - If a word has more than one pronunciation in the lexicon, only - the first one is kept, while other pronunciations are removed - from the lexicon. - - Args: - lexicon: - The input lexicon, containing a list of (word, [p1, p2, ..., pn]), - where "p1, p2, ..., pn" are the pronunciations of the "word". - Returns: - Return a new lexicon where each word has a unique pronunciation. - """ - seen = set() - ans = [] - - for word, tokens in lexicon: - if word in seen: - continue - seen.add(word) - ans.append((word, tokens)) - return ans - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - lexicon_filename = lang_dir / "lexicon.txt" - - in_lexicon = read_lexicon(lexicon_filename) - - out_lexicon = filter_multiple_pronunications(in_lexicon) - - write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) - - logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") - logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py b/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/prep_lexicon.sh b/egs/iwslt22_ta/ASR/local/prep_lexicon.sh deleted file mode 100755 index d394a0a22..000000000 --- a/egs/iwslt22_ta/ASR/local/prep_lexicon.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2022 QCRI (author: Amir Hussein) -# Apache 2.0 -# This script prepares the graphemic lexicon. - -dir=data/local/dict -stage=0 -lang_dir=$1 - -cat $lang_dir/transcript_words.txt | tr -s " " "\n" | sort -u > $lang_dir/uniq_words - -echo "$0: processing lexicon text and creating lexicon... $(date)." -# remove vowels and rare alef wasla -cat $lang_dir/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/words.txt - - -echo "$0: Lexicon preparation succeeded" diff --git a/egs/iwslt22_ta/ASR/local/prep_lexicon.sh b/egs/iwslt22_ta/ASR/local/prep_lexicon.sh new file mode 120000 index 000000000..996ed086d --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/prep_lexicon.sh @@ -0,0 +1 @@ +../../ST/local/prep_lexicon.sh \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/prepare_lang.py b/egs/iwslt22_ta/ASR/local/prepare_lang.py deleted file mode 100755 index 1f7120c99..000000000 --- a/egs/iwslt22_ta/ASR/local/prepare_lang.py +++ /dev/null @@ -1,414 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - # the next un-allocated state, will be incremented as we go. - next_state = 3 - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/prepare_lang.py b/egs/iwslt22_ta/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py b/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py deleted file mode 100755 index 24104581f..000000000 --- a/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.utils import str2bool -import pdb - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str] -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - words_pieces: List[List[str]] = sp.encode(words, out_type=str) - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - # The OOV word is - lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = dict() - for i in range(sp.vocab_size()): - token2id[sp.id_to_piece(i)] = i - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py b/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/prepare_lexicon.py b/egs/iwslt22_ta/ASR/local/prepare_lexicon.py deleted file mode 100755 index 807579503..000000000 --- a/egs/iwslt22_ta/ASR/local/prepare_lexicon.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2023 Johns Hopkins University (Amir Hussein) -# Apache 2.0 - -# This script prepares givel a column of words lexicon. - -import argparse - - -def get_args(): - parser = argparse.ArgumentParser( - description="""Creates the list of characters and words in lexicon""" - ) - parser.add_argument("input", type=str, help="""Input list of words file""") - parser.add_argument("output", type=str, help="""output graphemic lexicon""") - args = parser.parse_args() - return args - - -def main(): - lex = {} - args = get_args() - with open(args.input, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - characters = list(line) - characters = " ".join( - ["V" if char == "*" else char for char in characters] - ) - lex[line] = characters - - with open(args.output, "w", encoding="utf-8") as fp: - for key in sorted(lex): - fp.write(key + " " + lex[key] + "\n") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py deleted file mode 100755 index 2a2e3d56b..000000000 --- a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2023 Johns Hopkins University (Amir Hussein) - -#!/usr/bin/python -""" -This script prepares transcript_words.txt from cutset -""" - -from lhotse import CutSet -import argparse -import logging -import pdb -from pathlib import Path -import os - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--cut", - type=str, - default="", - help="Cutset file", - ) - parser.add_argument( - "--langdir", - type=str, - default="", - help="name of the lang-dir", - ) - return parser - - -def main(): - - parser = get_parser() - args = parser.parse_args() - - logging.info("Reading the cuts") - cuts = CutSet.from_file(args.cut) - langdir = args.langdir - - - if not os.path.exists(langdir): - os.makedirs(langdir) - - with open(langdir / "transcript_words.txt", 'w') as txt: - for c in cuts: - #breakpoint() - txt = c.supervisions[0].text - txt.write(src_txt + '\n') - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/prepare_transcripts.py b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py new file mode 120000 index 000000000..4a7e2b1c1 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/prepare_transcripts.py @@ -0,0 +1 @@ +/exp/ahussein/tmp/icefall/egs/iwslt22_ta/ST/local/prepare_transcripts.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/test_prepare_lang.py b/egs/iwslt22_ta/ASR/local/test_prepare_lang.py deleted file mode 100755 index d4cf62bba..000000000 --- a/egs/iwslt22_ta/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/test_prepare_lang.py b/egs/iwslt22_ta/ASR/local/test_prepare_lang.py new file mode 120000 index 000000000..f0f864998 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/test_prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/test_prepare_lang.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/train_bpe_model.py b/egs/iwslt22_ta/ASR/local/train_bpe_model.py deleted file mode 100755 index bc5812810..000000000 --- a/egs/iwslt22_ta/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the training corpus: transcript_words.txt. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/local/train_bpe_model.py b/egs/iwslt22_ta/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/local/validate_manifest.py b/egs/iwslt22_ta/ASR/local/validate_manifest.py new file mode 120000 index 000000000..0a9725e87 --- /dev/null +++ b/egs/iwslt22_ta/ASR/local/validate_manifest.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_manifest.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_finetune_datamodule.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_finetune_datamodule.py deleted file mode 100644 index 13bc882b3..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/asr_finetune_datamodule.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright 2022 Amir Hussein - -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class MGB2AsrDataModule: - - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank2"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir /"cuts_musan.jsonl.gz" - ) - - transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "callhome"/"cuts_teltrain_shuf.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome"/ "cuts_devall.jsonl.gz") - - @lru_cache() - def lev_test_cuts(self) -> CutSet: - logging.info("About to get lev test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" / "cuts_levtest.jsonl.gz") - - @lru_cache() - def iraqi_test_cuts(self) -> CutSet: - logging.info("About to get iraqi test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" / "cuts_iraqitest.jsonl.gz") - - @lru_cache() - def gulf_test_cuts(self) -> CutSet: - logging.info("About to get gukf test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_gulftest.jsonl.gz") - - @lru_cache() - def egy_test_cuts(self) -> CutSet: - logging.info("About to get egy test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egytest.jsonl.gz") - - @lru_cache() - def egy_sup_cuts(self) -> CutSet: - logging.info("About to get egy sup cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egysup.jsonl.gz") - - @lru_cache() - def egy_h5_cuts(self) -> CutSet: - logging.info("About to get egy h5 cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egyh5.jsonl.gz") - \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py deleted file mode 100644 index 5e9428b60..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1,2085 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from model import Transducer - -from icefall import NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = (logits / temperature).log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search_old.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search_old.py deleted file mode 100644 index ce8b04afd..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/beam_search_old.py +++ /dev/null @@ -1,977 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass -from typing import Dict, List, Optional - -import k2 -import torch -from model import Transducer - -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> List[List[int]]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - Returns: - Return the decoded result. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - ) - - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, -) -> List[List[int]]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - - Returns: - Return the decoded result. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - hyps = get_texts(best_path) - return hyps - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - return hyp - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = _get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - return ys - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py deleted file mode 100755 index e522943c0..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 - - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py new file mode 120000 index 000000000..d59ef95f7 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py deleted file mode 100644 index b6d94aaf1..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from scaling import ScaledConv1d, ScaledEmbedding - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = ScaledEmbedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = ScaledConv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim, - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py new file mode 120000 index 000000000..722e1c894 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py new file mode 120000 index 000000000..f58253127 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py deleted file mode 100755 index 513388113..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,243 +0,0 @@ -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - -# python pruned_transducer_stateless5/export.py --exp-dir pruned_transducer_stateless5/exp_streaming --streaming-model 1 --causal-convolution 1 --jit 1 --epoch 10 --avg 5 --bpe-model data/lang_bpe_2000/bpe.model -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=10, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp_streaming", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--streaming-model", - type=str2bool, - default=True, - help="""Whether to export a streaming model, if the models in exp-dir - are streaming model, this should be True. - """, - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.streaming_model: - assert params.causal_convolution - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py new file mode 120000 index 000000000..14fd0531d --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/export.py @@ -0,0 +1 @@ +../../ST/zipformer/export.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py deleted file mode 100644 index d5f4a7bd6..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear -from icefall.utils import is_jit_tracing - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) - self.output_linear = ScaledLinear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ -# assert encoder_out.ndim == decoder_out.ndim == 4 -# assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - -# if project_input: -# logit = self.encoder_proj(encoder_out) + self.decoder_proj( -# decoder_out -# ) - if not is_jit_tracing(): - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py new file mode 120000 index 000000000..9052f3cbb --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py deleted file mode 100644 index 272d06c37..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py new file mode 120000 index 000000000..a99e74334 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py deleted file mode 100644 index 432bf8220..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Union - -import torch -from torch.optim import Optimizer - - -class Eve(Optimizer): - r""" - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) - if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) - if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) - p.mul_(1 - (weight_decay * is_above_target_rms)) - p.addcdiv_(exp_avg, denom, value=-step_size) - - return loss - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) - - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - print( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) - - E.g. suggest initial-lr = 0.003 (passed to optimizer). - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - - def get_lr(self): - factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 - ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 - ) - return [x * factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = Eve(m.parameters(), lr=0.003) - - scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - print("last lr = ", scheduler.get_last_lr()) - print("state dict = ", scheduler.state_dict()) - - -if __name__ == "__main__": - _test_eden() diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py new file mode 120000 index 000000000..0a2f285aa --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py deleted file mode 100644 index 5ee4bab98..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1,719 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -from itertools import repeat -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from torch import Tensor - - -def _ntuple(n): - def parse(x): - if isinstance(x, collections.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs - - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() - ) ** -0.5 - return x * scales - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * self.weight_scale.exp() - bias = self.bias * self.bias_scale.exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - initial_speed: this affects how fast the parameter will - learn near the start of training; you can set it to a - value less than one if you suspect that a module - is contributing to instability near the start of training. - Nnote: regardless of the use of this option, it's best to - use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - """ - - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - if self.bias is None or self.bias_scale is None: - return None - - return self.bias * self.bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) - - -class ScaledConv1d(nn.Conv1d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledConv1d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - return bias * bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - (0,), - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -class ScaledConv2d(nn.Conv2d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledConv2d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - # see https://github.com/pytorch/pytorch/issues/24135 - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - return bias * bias_scale.exp() - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - (0, 0), - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting(): - return x - - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - x = x.detach() - s = torch.sigmoid(x - 1.0) - y = x * s - ctx.save_for_backward(s, y) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -class ScaledEmbedding(nn.Module): - r"""This is a modified version of nn.Embedding that introduces a learnable scale - on the parameters. Note: due to how we initialize it, it's best used with - schedulers like Noam that have a warmup period. - - It is a simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - initial_speed (float, optional): This affects how fast the parameter will - learn near the start of training; you can set it to a value less than - one if you suspect that a module is contributing to instability near - the start of training. Nnote: regardless of the use of this option, - it's best to use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - - """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - initial_speed: float = 1.0, - ) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" - elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters(initial_speed) - - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.1 / initial_speed - nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - scale = self.scale.exp() - if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) - else: - return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) - - def extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" - if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" - if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" - if self.sparse is not False: - s += ", sparse=True" - return s.format(**self.__dict__) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 0.5 - x.requires_grad = True - m = DoubleSwish() - torch.autograd.gradcheck(m, x) - - -if __name__ == "__main__": - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py new file mode 120000 index 000000000..c10cdfe12 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py deleted file mode 100644 index 06a81656c..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file provides functions to convert `ScaledLinear`, `ScaledConv1d`, -`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts: -`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`. -The scaled version are required only in the training time. It simplifies our -life by converting them to their non-scaled version during inference. -""" - -import copy -import re -from typing import List - -import torch -import torch.nn as nn -from lstmp import LSTMP -from scaling import ( - ActivationBalancer, - BasicNorm, - ScaledConv1d, - ScaledConv2d, - ScaledEmbedding, - ScaledLinear, - ScaledLSTM, -) - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, - keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: - """Convert an instance of ScaledLinear to nn.Linear. - Args: - scaled_linear: - The layer to be converted. - Returns: - Return a linear layer. It satisfies: - scaled_linear(x) == linear(x) - for any given input tensor `x`. - """ - assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear) - - weight = scaled_linear.get_weight() - bias = scaled_linear.get_bias() - has_bias = bias is not None - - linear = torch.nn.Linear( - in_features=scaled_linear.in_features, - out_features=scaled_linear.out_features, - bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not have - # this argument. Comment out for now, we will - # see if it will raise error for versions - # after v1.9.0 - ) - linear.weight.data.copy_(weight) - - if has_bias: - linear.bias.data.copy_(bias) - else: - linear.bias.data.zero_() - - return linear - - -def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d: - """Convert an instance of ScaledConv1d to nn.Conv1d. - Args: - scaled_conv1d: - The layer to be converted. - Returns: - Return an instance of nn.Conv1d that has the same `forward()` behavior - of the given `scaled_conv1d`. - """ - assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d) - - weight = scaled_conv1d.get_weight() - bias = scaled_conv1d.get_bias() - has_bias = bias is not None - - conv1d = nn.Conv1d( - in_channels=scaled_conv1d.in_channels, - out_channels=scaled_conv1d.out_channels, - kernel_size=scaled_conv1d.kernel_size, - stride=scaled_conv1d.stride, - padding=scaled_conv1d.padding, - dilation=scaled_conv1d.dilation, - groups=scaled_conv1d.groups, - bias=scaled_conv1d.bias is not None, - padding_mode=scaled_conv1d.padding_mode, - ) - - conv1d.weight.data.copy_(weight) - if has_bias: - conv1d.bias.data.copy_(bias) - - return conv1d - - -def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d: - """Convert an instance of ScaledConv2d to nn.Conv2d. - Args: - scaled_conv2d: - The layer to be converted. - Returns: - Return an instance of nn.Conv2d that has the same `forward()` behavior - of the given `scaled_conv2d`. - """ - assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d) - - weight = scaled_conv2d.get_weight() - bias = scaled_conv2d.get_bias() - has_bias = bias is not None - - conv2d = nn.Conv2d( - in_channels=scaled_conv2d.in_channels, - out_channels=scaled_conv2d.out_channels, - kernel_size=scaled_conv2d.kernel_size, - stride=scaled_conv2d.stride, - padding=scaled_conv2d.padding, - dilation=scaled_conv2d.dilation, - groups=scaled_conv2d.groups, - bias=scaled_conv2d.bias is not None, - padding_mode=scaled_conv2d.padding_mode, - ) - - conv2d.weight.data.copy_(weight) - if has_bias: - conv2d.bias.data.copy_(bias) - - return conv2d - - -def scaled_embedding_to_embedding( - scaled_embedding: ScaledEmbedding, -) -> nn.Embedding: - """Convert an instance of ScaledEmbedding to nn.Embedding. - Args: - scaled_embedding: - The layer to be converted. - Returns: - Return an instance of nn.Embedding that has the same `forward()` behavior - of the given `scaled_embedding`. - """ - assert isinstance(scaled_embedding, ScaledEmbedding), type( - scaled_embedding) - embedding = nn.Embedding( - num_embeddings=scaled_embedding.num_embeddings, - embedding_dim=scaled_embedding.embedding_dim, - padding_idx=scaled_embedding.padding_idx, - scale_grad_by_freq=scaled_embedding.scale_grad_by_freq, - sparse=scaled_embedding.sparse, - ) - weight = scaled_embedding.weight - scale = scaled_embedding.scale - - embedding.weight.data.copy_(weight * scale.exp()) - - return embedding - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: - """Convert an instance of ScaledLSTM to nn.LSTM. - Args: - scaled_lstm: - The layer to be converted. - Returns: - Return an instance of nn.LSTM that has the same `forward()` behavior - of the given `scaled_lstm`. - """ - assert isinstance(scaled_lstm, ScaledLSTM), type(scaled_lstm) - lstm = nn.LSTM( - input_size=scaled_lstm.input_size, - hidden_size=scaled_lstm.hidden_size, - num_layers=scaled_lstm.num_layers, - bias=scaled_lstm.bias, - batch_first=scaled_lstm.batch_first, - dropout=scaled_lstm.dropout, - bidirectional=scaled_lstm.bidirectional, - proj_size=scaled_lstm.proj_size, - ) - - assert lstm._flat_weights_names == scaled_lstm._flat_weights_names - for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * \ - scaled_lstm._scales[idx].exp() - lstm._flat_weights[idx].data.copy_(scaled_weight) - - return lstm - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_onnx: bool = False, -): - """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` - in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, - and `nn.Conv2d`. - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_onnx: - If True, we are going to export the model to ONNX. In this case, - we will convert nn.LSTM with proj_size to LSTMP. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - excluded_patterns = r"(self|src)_attn\.(in|out)_proj" - p = re.compile(excluded_patterns) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, ScaledLinear): - if p.search(name) is not None: - continue - d[name] = scaled_linear_to_linear(m) - elif isinstance(m, ScaledConv1d): - d[name] = scaled_conv1d_to_conv1d(m) - elif isinstance(m, ScaledConv2d): - d[name] = scaled_conv2d_to_conv2d(m) - elif isinstance(m, ScaledEmbedding): - d[name] = scaled_embedding_to_embedding(m) - elif isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, ScaledLSTM): - if is_onnx: - d[name] = LSTMP(scaled_lstm_to_lstm(m)) - # See - # https://github.com/pytorch/pytorch/issues/47887 - # d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m))) - else: - d[name] = scaled_lstm_to_lstm(m) - elif isinstance(m, ActivationBalancer): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..e58473a04 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py deleted file mode 100644 index e6e0fb1c8..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py new file mode 120000 index 000000000..2f76638ac --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py deleted file mode 100755 index 0a61c9493..000000000 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py +++ /dev/null @@ -1,608 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-size 8 \ - --left-context 32 \ - --right-context 0 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --decoding_method greedy_search \ - --num-decode-streams 1000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import ( - AttributeDict, - str2bool, - setup_logger, - store_transcripts, - write_error_stats, -) -import pdb - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num-active-paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--right-context", - type=int, - default=0, - help="right context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames( - params.decode_chunk_size * params.subsampling_factor - ) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # if T is less than 7 there will be an error in time reduction layer, - # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - # we plus 2 here because we will cut off one frame on each size of - # encoder_embed output as they see invalid paddings. so we need extra 2 - # frames. - tail_length = 7 + (2 + params.right_context) * params.subsampling_factor - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = [ - torch.stack([x[0] for x in states], dim=2), - torch.stack([x[1] for x in states], dim=2), - ] - - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - left_context=params.left_context, - right_context=params.right_context, - processed_lens=processed_lens, - ) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, - streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}") - - states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = [states[0][i], states[1][i]] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device) - for num, cut_ in enumerate(cuts): - # each utterance has a DecodeStream. - for cut in cut_["supervisions"]["cut"]: - # pdb.set_trace() - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - decode_stream.set_features(fbank(samples.to(device))) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - # pdb.set_trace() - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode( - decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode( - decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}") - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / - f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # sort results so we can easily compare the difference between two - # recognition results - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / - f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - params.suffix += f"-right-context-{params.right_context}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - params.causal_convolution = True - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - MGB2 = MGB2AsrDataModule(args) - - test_cuts = MGB2.test_cuts() - dev_cuts = MGB2.dev_cuts() - - test_dl = MGB2.test_dataloaders(test_cuts) - dev_dl = MGB2.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - cuts=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py new file mode 120000 index 000000000..f29284163 --- /dev/null +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py index 5095c16ce..68f01dfac 100755 --- a/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py +++ b/egs/iwslt22_ta/ASR/pruned_transducer_stateless5/train.py @@ -227,9 +227,15 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - default="data/lang_bpe_1000/bpe.model", + default="data/lang_bpe_ta_1000/bpe.model", help="Path to source data BPE model", ) + parser.add_argument( + "--bpe-tgt-model", + type=str, + default="data/lang_bpe_en_1000/bpe.model", + help="Path to target data BPE model", + ) parser.add_argument( "--initial-lr", type=float, @@ -611,6 +617,7 @@ def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, batch: dict, is_training: bool, warmup: float = 1.0, @@ -648,8 +655,11 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) #pdb.set_trace() texts = batch["supervisions"]["text"] + tgt_texts = batch["supervisions"]["tgt_text"] y = sp.encode(texts, out_type=int) + y_tgt = sp_tgt.encode(tgt_texts, out_type=int) y = k2.RaggedTensor(y).to(device) + y_tgt = k2.RaggedTensor(y_tgt).to(device) with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( @@ -726,6 +736,7 @@ def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -739,6 +750,7 @@ def compute_validation_loss( params=params, model=model, sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=False, ) @@ -762,6 +774,7 @@ def train_one_epoch( optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -821,6 +834,7 @@ def train_one_epoch( params=params, model=model, sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=True, warmup=( @@ -913,6 +927,7 @@ def train_one_epoch( params=params, model=model, sp=sp, + sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, ) @@ -992,7 +1007,9 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") sp = spm.SentencePieceProcessor() + sp_tgt = spm.SentencePieceProcessor() sp.load(params.bpe_model) + sp_tgt.load(params.bpe_tgt_model) # pdb.set_trace() # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") @@ -1122,6 +1139,7 @@ def run(rank, world_size, args): train_dl=train_dl, optimizer=optimizer, sp=sp, + sp_tgt=sp_tgt, params=params, warmup=0.0 if params.start_epoch == 1 else 1.0, ) @@ -1149,6 +1167,7 @@ def run(rank, world_size, args): optimizer=optimizer, scheduler=scheduler, sp=sp, + sp_tgt=sp_tgt, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1217,6 +1236,7 @@ def scan_pessimistic_batches_for_oom( train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, params: AttributeDict, warmup: float, ): @@ -1238,6 +1258,7 @@ def scan_pessimistic_batches_for_oom( params=params, model=model, sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=True, warmup=warmup, diff --git a/egs/iwslt22_ta/ASR/shared b/egs/iwslt22_ta/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/iwslt22_ta/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/beam_search.py b/egs/iwslt22_ta/ASR/zipformer/beam_search.py deleted file mode 100644 index 1eaa38049..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/beam_search.py +++ /dev/null @@ -1,3015 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from model import Transducer - -from icefall import NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.lm_wrapper import LmScorer -from icefall.rnn_lm.model import RnnLmModel -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - subtract_ilme: bool = False, - ilme_scale: float = 0.1, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - subtract_ilme=subtract_ilme, - ilme_scale=ilme_scale, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - subtract_ilme: bool = False, - ilme_scale: float = 0.1, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = (logits / temperature).log_softmax(dim=-1) - if subtract_ilme: - ilme_logits = model.joiner( - torch.zeros_like( - current_encoder_out, device=current_encoder_out.device - ).unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - ilme_logits = ilme_logits.squeeze(1).squeeze(1) - ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) - log_probs -= ilme_scale * ilme_log_probs - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - # scores[n][i] is the logits on which hyp[n][i] is decoded - scores = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - scores[i].append(logits[i, v].item()) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - ans_scores = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - ans_scores.append(scores[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - scores=ans_scores, - ) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": - """Return the top-k hypothesis. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - """ - hyps = list(self._data.items()) - - if length_norm: - hyps = sorted( - hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True - )[:k] - else: - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, - use_hat: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - if use_hat == True: - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - #breakpoint() - nb_shift = logp_b - logits[..., 0] - nb_shift = nb_shift.unsqueeze(-1) - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b.unsqueeze(-1), log_probs1), dim=-1) - log_probs.add_(ys_log_probs) - else: - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - -def modified_beam_search_hat( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - - # For blank symbol, log-prob is log-sigmoid of the score - logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) - # Additionally, to ensure the the probs of blank and non-blank sum to 1, we - # need to add the following term to the log-probs of non-blank symbols. This - # is equivalent to log(1 - sigmoid(logits[..., 0])). - breakpoint() - nb_shift = logp_b - logits[..., 0] - - log_probs1 = (logits[..., 1:] / temperature).log_softmax(dim=-1) + nb_shift # (num_hyps, vocab_size-1) - log_probs = torch.cat((logp_b, log_probs), dim=-1) - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - -def modified_beam_search_lm_rescore( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - # get the best hyp with different lm_scale - for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}" - tot_scores = am_scores.values + lm_scores * lm_scale - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def modified_beam_search_lm_rescore_LODR( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - LODR_lm: NgramLm, - sp: spm.SentencePieceProcessor, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - # now LODR scores - import math - - LODR_scores = [] - for seq in candidate_seqs: - tokens = " ".join(sp.id_to_piece(seq)) - LODR_scores.append(LODR_lm.score(tokens)) - LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( - 10 - ) # arpa scores are 10-based - assert lm_scores.shape == LODR_scores.shape - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - LODR_scale_list = [0.05 * i for i in range(1, 20)] - # get the best hyp with different lm_scale and lodr_scale - for lm_scale in lm_scale_list: - for lodr_scale in LODR_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" - tot_scores = ( - am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale - ) - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_LODR( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LODR_lm: NgramLm, - LODR_lm_scale: float, - LM: LmScorer, - beam: int = 4, -) -> List[List[int]]: - """This function implements LODR (https://arxiv.org/abs/2203.16776) with - `modified_beam_search`. It uses a bi-gram language model as the estimate - of the internal language model and subtracts its score during shallow fusion - with an external language model. This implementation uses a RNNLM as the - external language model. - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - LODR_lm: - A low order n-gram LM, whose score will be subtracted during shallow fusion - LODR_lm_scale: - The scale of the LODR_lm - LM: - A neural net LM, e.g an RNNLM or transformer LM - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the NN LM - lm_score=init_score.reshape(-1), - state_cost=NgramLmStateCost( - LODR_lm - ), # state of the source domain ngram - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - LM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - # forward NN LM to get new states and scores - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - # current score of hyp - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - - # calculate the score of the latest token - current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score - - assert current_ngram_score <= 0.0, ( - state_cost.lm_score, - hyp.state_cost.lm_score, - ) - # score = score + TDLM_score - LODR_score - # LODR_LM_scale should be a negative number here - hyp_log_prob += ( - lm_score[new_token] * lm_scale - + LODR_lm_scale * current_ngram_score - ) # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - else: - state_cost = hyp.state_cost - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - state_cost=state_cost, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_lm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + NN LM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - LM (LmScorer): - A neural net LM, e.g RNN or Transformer - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - lm_scores = torch.cat( - [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - `LM` will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] # a list of list - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) diff --git a/egs/iwslt22_ta/ASR/zipformer/beam_search.py b/egs/iwslt22_ta/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/decoder.py b/egs/iwslt22_ta/ASR/zipformer/decoder.py deleted file mode 100644 index 45432d570..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/decoder.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from scaling import Balancer - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - padding_idx=blank_id, - ) - # the balancers are to avoid any drift in the magnitude of the - # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) - - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - - embedding_out = self.balancer(embedding_out) - - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) - - return embedding_out diff --git a/egs/iwslt22_ta/ASR/zipformer/decoder.py b/egs/iwslt22_ta/ASR/zipformer/decoder.py new file mode 120000 index 000000000..cefc9926e --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../ST/zipformer/decoder.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/export.py b/egs/iwslt22_ta/ASR/zipformer/export.py deleted file mode 100755 index b996470aa..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/export.py +++ /dev/null @@ -1,523 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -- non-streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -- streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import sentencepiece as spm -import torch -from torch import Tensor, nn -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, str2bool -from scaling_converter import convert_scaled_to_non_scaled - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/export.py b/egs/iwslt22_ta/ASR/zipformer/export.py new file mode 120000 index 000000000..14fd0531d --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../ST/zipformer/export.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py b/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py deleted file mode 100755 index fe29355f2..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) use the checkpoint exp_dir/epoch-xxx.pt -./zipformer/generate_averaged_model.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp - -It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. - -(2) use the checkpoint exp_dir/checkpoint-iter.pt -./zipformer/generate_averaged_model.py \ - --iter 22000 \ - --avg 5 \ - --exp-dir ./zipformer/exp - -It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. -""" - - -import argparse -from pathlib import Path - -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule - -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints_with_averaged_model, - find_checkpoints, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - device = torch.device("cpu") - print(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - print("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py b/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 000000000..edee92fe2 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ +../../ST/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py deleted file mode 100755 index 4092d165e..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1,272 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained.py \ - --nn-model-filename ./zipformer/exp/cpu_jit.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - features=features, - feature_lengths=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..0cda00f55 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../ST/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 100755 index 58d736685..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1,269 +0,0 @@ -#!/usr/bin/env python3 -# flake8: noqa -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained_streaming.py \ - --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ -""" - -import argparse -import logging -import math -from typing import List, Optional - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, - device: torch.device = torch.device("cpu"), -): - assert encoder_out.ndim == 2 - context_size = 2 - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) - # decoder_input.shape (1,, 1 context_size) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - else: - assert decoder_out.ndim == 2 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for i in range(T): - cur_encoder_out = encoder_out[i : i + 1] - joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - - decoder_input = torch.tensor( - decoder_input, dtype=torch.int32, device=device - ).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - - return hyp, decoder_out - - -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - model.eval() - model.to(device) - - encoder = model.encoder - decoder = model.decoder - joiner = model.joiner - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - - logging.info(f"Reading sound files: {args.sound_file}") - wave_samples = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=args.sample_rate, - )[0] - logging.info(wave_samples.shape) - - logging.info("Decoding started") - - chunk_length = encoder.chunk_size * 2 - T = chunk_length + encoder.pad_length - - logging.info(f"chunk_length: {chunk_length}") - logging.info(f"T: {T}") - - states = encoder.get_init_states(device=device) - - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.2 second - num_processed_frames = 0 - - hyp = None - decoder_out = None - - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32, device=device) - encoder_out, out_lens, states = encoder( - features=frames, - feature_lengths=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device - ) - - context_size = 2 - logging.info(args.sound_file) - logging.info(sp.decode(hyp[context_size:])) - - logging.info("Decoding Done") - - -torch.set_num_threads(4) -torch.set_num_interop_threads(1) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..5c5961a4b --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../ST/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/joiner.py b/egs/iwslt22_ta/ASR/zipformer/joiner.py deleted file mode 100644 index f03cc930e..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/joiner.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/iwslt22_ta/ASR/zipformer/joiner.py b/egs/iwslt22_ta/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/model.py b/egs/iwslt22_ta/ASR/zipformer/model.py deleted file mode 100644 index 44d1dca59..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/model.py +++ /dev/null @@ -1,489 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - self.simple_lm_proj = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - x, x_lens = self.encoder_embed(x, x_lens) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - use_hat_loss=True, - ) - - return (simple_loss, pruned_loss) - - -class Transducer_asr_st(Transducer): - """ - "Sequence Transduction with Recurrent Neural Networks for multitask ASR and Speech Translation" - """ - - def __init__( - self, - encoder_embed: nn.Module, - encoder: EncoderInterface, - decoder: nn.Module, - decoder_tgt: nn.Module, - joiner: nn.Module, - joiner_tgt: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder_embed: - It is a Convolutional 2D subsampling module. It converts - an input of shape (N, T, idim) to an output of of shape - (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder_embed = encoder_embed - self.encoder = encoder - self.decoder = decoder - self.decoder_tgt = decoder_tgt - self.joiner = joiner - self.joiner_tgt = joiner_tgt - - self.simple_am_proj = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - - self.simple_am_proj_tgt = ScaledLinear( - encoder_dim, - vocab_size, - initial_scale=0.25, - ) - - self.simple_lm_proj = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - self.simple_lm_proj_tgt = ScaledLinear( - decoder_dim, - vocab_size, - initial_scale=0.25, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - y_tgt: k2.RaggedTensor, - prune_range: int = 5, - prune_range_tgt: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") - x, x_lens = self.encoder_embed(x, x_lens) - # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - row_splits_tgt = y_tgt.shape.row_splits(1) - y_lens_tgt = row_splits_tgt[1:] - row_splits_tgt[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - blank_id_tgt = self.decoder_tgt.blank_id - sos_y_tgt = add_sos(y_tgt, sos_id=blank_id_tgt) - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - sos_y_padded_tgt = sos_y_tgt.pad(mode="constant", padding_value=blank_id_tgt) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - decoder_out_tgt = self.decoder_tgt(sos_y_padded_tgt) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # tgt - y_padded_tgt = y_tgt.pad(mode="constant", padding_value=0) - - y_padded_tgt = y_padded_tgt.to(torch.int64) - boundary_tgt = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary_tgt[:, 2] = y_lens_tgt - boundary_tgt[:, 3] = x_lens - - lm_tgt = self.simple_lm_proj_tgt(decoder_out_tgt) - am_tgt = self.simple_am_proj_tgt(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss_tgt, (px_grad_tgt, py_grad_tgt) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded_tgt, - termination_symbol=blank_id_tgt, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary_tgt, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - ranges_tgt = k2.get_rnnt_prune_ranges( - px_grad=px_grad_tgt, - py_grad=py_grad_tgt, - boundary=boundary_tgt, - s_range=prune_range_tgt, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - am_pruned_tgt, lm_pruned_tgt = k2.do_rnnt_pruning( - am=self.joiner_tgt.encoder_proj(encoder_out), - lm=self.joiner_tgt.decoder_proj(decoder_out), - ranges=ranges_tgt, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - logits_tgt = self.joiner_tgt(am_pruned_tgt, lm_pruned_tgt, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - pruned_loss_tgt = k2.rnnt_loss_pruned( - logits=logits_tgt.float(), - symbols=y_padded_tgt, - ranges=ranges_tgt, - termination_symbol=blank_id_tgt, - boundary=boundary_tgt, - reduction="sum", - ) - - return (simple_loss, pruned_loss, simple_loss_tgt, pruned_loss_tgt) diff --git a/egs/iwslt22_ta/ASR/zipformer/model.py b/egs/iwslt22_ta/ASR/zipformer/model.py new file mode 120000 index 000000000..dbf1ff29b --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../ST/zipformer/model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/optim.py b/egs/iwslt22_ta/ASR/zipformer/optim.py deleted file mode 100644 index abfb2092c..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/optim.py +++ /dev/null @@ -1,1173 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accept model.parameters() or groups of parameters(), - this optimizer could accept model.named_parameters() or groups of named_parameters(). - See comments of function _get_names_of_parameters for its 4 possible cases. - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): - - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True - param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. - # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] - del cur_group["named_params"] - cur_group["params"] = p_list - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - - with self.batched_params(group["params"], group_params_names) as batches: - - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter which dominates tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter dominating tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq={(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/iwslt22_ta/ASR/zipformer/optim.py b/egs/iwslt22_ta/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/pretrained.py b/egs/iwslt22_ta/ASR/zipformer/pretrained.py deleted file mode 100755 index a4b7c2c36..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/pretrained.py +++ /dev/null @@ -1,382 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -- For non-streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -- For streaming model: - -(1) greedy search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) modified beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) fast beam search -./zipformer/pretrained.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - /path/to/foo.wav \ - /path/to/bar.wav - - -You can also use `./zipformer/exp/epoch-xx.pt`. - -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from icefall.utils import make_pad_mask -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - # model forward - x, x_lens = model.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder( - x, x_lens, src_key_padding_mask - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - msg = f"Using {params.method}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - raise ValueError(f"Unsupported method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/pretrained.py b/egs/iwslt22_ta/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..8096cbfe9 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../ST/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/scaling.py b/egs/iwslt22_ta/ASR/zipformer/scaling.py deleted file mode 100644 index 908b60938..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/scaling.py +++ /dev/null @@ -1,1797 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional, Tuple, Union -import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random -import torch -import math -import torch.nn as nn -from torch import Tensor - - -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise lienar - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. - """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specifiy the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - def __init__(self, - *args, - default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' - - def __float__(self): - batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting(): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) - else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) - else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = (x_abs < min_abs) - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = (x > self.cutoff) - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x ** 2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return (x - bias) * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None - - -class BiasNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. - """ - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False - ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) - return x * scales - - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) - - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) - - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) - - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = (m_loss + r_loss) - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, - num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, - x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = (y * (1 - s) + s) - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = (d * ((ceil - floor) / 255.0) + floor) - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - ans, = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter('bias', l.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, - x: Tensor): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = ((1.2-(-0.043637))/255.0) - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshL() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshR() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) - with torch.no_grad(): - m2.weight[:] = m1[2].weight - if bias: - m2.bias[:] = m1[2].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) - if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() - # the SwooshL() implementation has a noisy gradient due to 1-byte - # storage of it. - assert isclose(x1.grad, x2.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_piecewise_linear() - _test_softmax() - _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_double_swish_deriv() - _test_swooshr_deriv() - _test_swooshl_deriv() - _test_activation_dropout_and_linear() diff --git a/egs/iwslt22_ta/ASR/zipformer/scaling.py b/egs/iwslt22_ta/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py b/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py deleted file mode 100644 index 683a03461..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List, Tuple - -import torch -import torch.nn as nn -from scaling import Balancer, Dropout3, ScaleGrad, Whiten - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py b/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py b/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py deleted file mode 100644 index e6e0fb1c8..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py b/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py b/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py deleted file mode 100755 index c2d58cb1e..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1,876 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. - """ - cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) - elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = get_init_states( - model=model, batch_size=1, device=device - ) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py b/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py new file mode 120000 index 000000000..13fd02a78 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/subsampling.py b/egs/iwslt22_ta/ASR/zipformer/subsampling.py deleted file mode 100644 index 47403f13c..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/subsampling.py +++ /dev/null @@ -1,407 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple -import warnings - -import torch -from torch import Tensor, nn -from scaling import ( - Balancer, - BiasNorm, - Dropout3, - FloatLike, - Optional, - ScaledConv2d, - ScaleGrad, - ScheduledFloat, - SwooshL, - SwooshR, - Whiten, -) - - -class ConvNeXt(nn.Module): - """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf - """ - - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, - ): - super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=self.padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 - ) - - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - initial_scale=0.01, - ) - - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None - ) -> Tensor: - """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. - """ - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - if layer_skip_mask is not None: - x = x * layer_skip_mask - - x = bypass + x - x = self.out_balancer(x) - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) - - return x - - def streaming_forward( - self, - x: Tensor, - cached_left_pad: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) - - Returns: - - The returned value has the same shape as x. - - Updated cached_left_pad. - """ - padding = self.padding - - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] - - bypass = x[:, :, :T, :] - - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - x = bypass + x - return x, cached_left_pad - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite - """ - assert in_channels >= 7 - super().__init__() - - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), - SwooshR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - ) - - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - self.out_width = (((in_channels - 1) // 2) - 1) // 2 - self.layer3_channels = layer3_channels - - self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), - prob=(0.025, 0.25), - grad_scale=0.02, - ) - - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - - output lengths, of shape (batch_size,) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - if torch.jit.is_scripting(): - x_lens = (x_lens - 7) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() - - return x, x_lens - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - cached_left_pad: Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - - output lengths, of shape (batch_size,) - - updated cache - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - - # T' = (T-7)//2 - x = self.conv(x) - - # T' = (T-7)//2-3 - x, cached_left_pad = self.convnext.streaming_forward( - x, cached_left_pad=cached_left_pad - ) - - # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, T', out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, T', odim) - x = self.out_norm(x) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert self.convnext.padding[0] == 3 - # The ConvNeXt module needs 3 frames of right padding after subsampling - x_lens = (x_lens - 7) // 2 - 3 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # The ConvNeXt module needs 3 frames of right padding after subsampling - assert self.convnext.padding[0] == 3 - x_lens = (x_lens - 7) // 2 - 3 - - assert x.size(1) == x_lens.max().item() - - return x, x_lens, cached_left_pad - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> Tensor: - """Get initial states for Conv2dSubsampling module. - It is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - """ - left_pad = self.convnext.padding[0] - freq = self.out_width - channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) - - return cached_embed_left_pad diff --git a/egs/iwslt22_ta/ASR/zipformer/subsampling.py b/egs/iwslt22_ta/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ASR/zipformer/zipformer.py b/egs/iwslt22_ta/ASR/zipformer/zipformer.py deleted file mode 100644 index 8d90198fd..000000000 --- a/egs/iwslt22_ta/ASR/zipformer/zipformer.py +++ /dev/null @@ -1,2237 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import warnings -from typing import List, Optional, Tuple, Union -import logging -import torch -import random -from encoder_interface import EncoderInterface -from scaling import ( - Balancer, - BiasNorm, - Dropout2, - ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - penalize_abs_values_gt, - softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), - (20000.0, 0.1)) - - def _to_tuple(x): - """ Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - for u,d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) - - def get_feature_masks( - self, - x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [ 1.0 ] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0 - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and(mask1, - (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones(1, batch_size, channels, - dtype=x.dtype, device=x.device) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - outputs = [] - if torch.jit.is_scripting(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module(x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), - attn_mask=attn_mask) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, - chunk_size: int, - left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all (chunk_size * left_context_chunks >= - (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders)) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, - src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [ outputs[-1] ] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) - cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), - (20000.0, ratio * x), - default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), - ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, pos_dim=pos_dim, num_heads=num_heads, - query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.feed_forward1 = FeedforwardModule(embed_dim, - (feedforward_dim * 3) // 4, - dropout) - - self.feed_forward2 = FeedforwardModule(embed_dim, - feedforward_dim, - dropout) - - self.feed_forward3 = FeedforwardModule(embed_dim, - (feedforward_dim * 5) // 4, - dropout) - - self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) - - self.conv_module1 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - self.conv_module2 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.2, max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.balancer2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.1, max_abs=4.0, - ) - - def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting(): - pass - elif not self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) - selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), - ff2_skip_rate) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), - ff3_skip_rate) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na - - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) - - src = self.bypass(src_orig, src) - - return ( - src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.0) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1. / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - output = src - - if not torch.jit.is_scripting(): - output = output * feature_mask - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if not torch.jit.is_scripting(): - output = output * feature_mask - - return output - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 - ) = states[i * 6: (i + 1) * 6] - ( - output, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2 - ) = mod.streaming_forward( - output, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - return output, new_states - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. - """ - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value(self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max)) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, - src_orig: Tensor, - src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - def __init__(self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, - downsample, dropout) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds,::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src), new_states - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - def __init__(self, - channels: int, - downsample: int, - dropout: FloatLike): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - def __init__(self, - num_channels: int, - upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - def __init__( - self, embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0 - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0 - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T-1), T, - device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = (self.embed_dim ** 0.5) - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - : - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.0)) - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=query_head_dim**-0.25) - - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be suffixient to fix the problem. - self.balance_keys = Balancer(key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(pos_dim, - num_heads * pos_head_dim, - bias=False, - initial_scale=0.05) - - # the following are for diagnosics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt(attn_scores, - limit=25.0, - penalty=1.0e-04, - name=self.name) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy( - self, - attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).mean(dim=(1,2)) - logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, - num_heads * value_head_dim, - bias=True) - - self.out_proj = ScaledLinear(num_heads * value_head_dim, - embed_dim, bias=True, - initial_scale=0.05) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model. - """ - def __init__(self, - embed_dim: int, - feedforward_dim: int, - dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer(feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, - activation='SwooshL', - dropout_p=dropout, - dropout_shared_dim=0, bias=True, - initial_scale=0.1) - - self.out_whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear(hidden_channels, channels, - bias=True, - initial_scale=0.05) - - self.whiten1 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.whiten2 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - def __init__( - self, channels: int, kernel_size: int, causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.balancer1 = Balancer( - bottleneck_dim, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=bottleneck_dim, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2) - - self.balancer2 = Balancer( - bottleneck_dim, channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, channels, activation='SwooshR', - dropout_p=0.0, initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if not torch.jit.is_scripting() and chunk_size >= 0: - # Not support exporting a model for simulated streaming decoding - assert self.causal, "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - - c = Zipformer2( - encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,) - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) diff --git a/egs/iwslt22_ta/ASR/zipformer/zipformer.py b/egs/iwslt22_ta/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/iwslt22_ta/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/README.md b/egs/iwslt22_ta/ST/README.md index dd085373b..ada7d0453 100644 --- a/egs/iwslt22_ta/ST/README.md +++ b/egs/iwslt22_ta/ST/README.md @@ -14,15 +14,15 @@ https://aclanthology.org/2022.iwslt-1.10/. | Decoding method | dev Bleu | test Bleu | comment | |------------------------------------|------------|------------|------------------------------------------| -| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 10, beam(10), pruned range 5 | +| modified beam search | 11.1 | 9.2 | --epoch 20, --avg 13, beam(10), pruned range 5 | ## Zipformer Performance Record (after 20 epochs) | Decoding method | dev Bleu | test Bleu | comment | |------------------------------------|------------|------------|------------------------------------------| -| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 10, beam(10),pruned range 5 | -| modified beam search | 15.5 | 13 | --epoch 20, --avg 10, beam(20),pruned range 5 | -| modified beam search | 17.6 | 14.8 | --epoch 20, --avg 10, beam(10), pruned range 10 | +| modified beam search | 14.7 | 12.4 | --epoch 20, --avg 13, beam(10),pruned range 5 | +| modified beam search | 15.5 | 13 | --epoch 20, --avg 13, beam(20),pruned range 5 | +| modified beam search | 17.9 | 14.9 | --epoch 20, --avg 13, beam(20), pruned range 10 | See [RESULTS](/egs/iwslt_ta/ST/RESULTS.md) for details. diff --git a/egs/iwslt22_ta/ST/RESULTS.md b/egs/iwslt22_ta/ST/RESULTS.md index c33c1234d..532993fae 100644 --- a/egs/iwslt22_ta/ST/RESULTS.md +++ b/egs/iwslt22_ta/ST/RESULTS.md @@ -17,7 +17,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless5/train_st.py \ +./pruned_transducer_stateless5/train.py \ --world-size 4 \ --num-epochs 20 \ --start-epoch 1 \ @@ -34,11 +34,11 @@ The decoding command is: ``` for method in modified_beam_search; do for epoch in 15 20; do - ./pruned_transducer_stateless5/decode_st.py \ + ./pruned_transducer_stateless5/decode.py \ --epoch $epoch \ --beam-size 20 \ --avg 10 \ - --exp-dir ./pruned_transducer_stateless5/exp_st_single_task2 \ + --exp-dir ./pruned_transducer_stateless5/exp_st \ --max-duration 300 \ --decoding-method $method \ --max-sym-per-frame 1 \ @@ -75,21 +75,23 @@ To reproduce the above result, use the following commands for training: # ST medium model 42.5M prune-range 10 ``` - ./zipformer/train_st.py \ - --world-size 4 \ - --num-epochs 20 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-st-medium-prun10 \ - --causal 0 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,1024,1536,1024,768 \ - --encoder-dim 192,256,384,512,384,256 \ - --encoder-unmasked-dim 192,192,256,256,256,192 \ - --max-duration 300 \ - --context-size 2 \ - --prune-range 10 - --prune-range 10 + ./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-st-medium-nohat800s-warmstep8k_baselr05_lrbatch5k_lrepoch6 \ + --causal 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --max-duration 800 \ + --prune-range 10 \ + --warm-step 8000 \ + --lr-epochs 6 \ + --base-lr 0.055 \ + --use-hat False ``` @@ -101,19 +103,19 @@ The decoding command is: ``` for method in modified_beam_search; do for epoch in 15 20; do - ./zipformer/decode_st.py \ + ./zipformer/decode.py \ --epoch $epoch \ --beam-size 20 \ - --avg 10 \ + --avg 13 \ --exp-dir ./zipformer/exp-st-medium-prun10 \ --max-duration 800 \ --decoding-method $method \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,1024,1536,1024,768 \ - --encoder-dim 192,256,384,512,384,256 \ - --encoder-unmasked-dim 192,192,256,256,256,192 \ - --context-size 2 \ - --use-averaged-model true + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --context-size 2 \ + --use-averaged-model true done done ``` diff --git a/egs/iwslt22_ta/ST/local/compile_hlg.py b/egs/iwslt22_ta/ST/local/compile_hlg.py deleted file mode 100755 index 9a35750e0..000000000 --- a/egs/iwslt22_ta/ST/local/compile_hlg.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input lang_dir and generates HLG from - - - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt - - L, the lexicon, built from lang_dir/L_disambig.pt - - Caution: We use a lexicon that contains disambiguation symbols - - - G, the LM, built from data/lm/G_3_gram.fst.txt - -The generated HLG is saved in $lang_dir/HLG.pt -""" -import argparse -import logging -from pathlib import Path - -import k2 -import torch - -from icefall.lexicon import Lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - """, - ) - - return parser.parse_args() - - -def compile_HLG(lang_dir: str) -> k2.Fsa: - """ - Args: - lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe_5000. - - Return: - An FSA representing HLG. - """ - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") - G = k2.Fsa.from_dict(d) - else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") - - first_token_disambig_id = lexicon.token_table["#0"] - first_word_disambig_id = lexicon.word_table["#0"] - - L = k2.arc_sort(L) - G = k2.arc_sort(G) - - logging.info("Intersecting L and G") - LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") - - logging.info("Connecting LG") - LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") - - logging.info(type(LG.aux_labels)) - logging.info("Determinizing LG") - - LG = k2.determinize(LG) - logging.info(type(LG.aux_labels)) - - logging.info("Connecting LG after k2.determinize") - LG = k2.connect(LG) - - logging.info("Removing disambiguation symbols on LG") - - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - - assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 - - LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") - - LG = k2.connect(LG) - LG.aux_labels = LG.aux_labels.remove_values_eq(0) - - logging.info("Arc sorting LG") - LG = k2.arc_sort(LG) - - logging.info("Composing H and LG") - # CAUTION: The name of the inner_labels is fixed - # to `tokens`. If you want to change it, please - # also change other places in icefall that are using - # it. - HLG = k2.compose(H, LG, inner_labels="tokens") - - logging.info("Connecting LG") - HLG = k2.connect(HLG) - - logging.info("Arc sorting LG") - HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") - - return HLG - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") - return - - logging.info(f"Processing {lang_dir}") - - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py b/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py index 84e17addb..05ed0a74a 100755 --- a/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py +++ b/egs/iwslt22_ta/ST/local/compute_fbank_gpu.py @@ -45,8 +45,6 @@ from lhotse.features.kaldifeat import ( # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect # even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) def get_args(): parser = argparse.ArgumentParser() @@ -91,7 +89,7 @@ def compute_fbank_gpu(args): "dev", ) manifests = read_manifests_if_cached( - prefix="iwslt", dataset_parts=dataset_parts, output_dir=src_dir + prefix="iwslt-ta", dataset_parts=dataset_parts, output_dir=src_dir ) assert manifests is not None diff --git a/egs/iwslt22_ta/ST/local/compute_fbank_musan.py b/egs/iwslt22_ta/ST/local/compute_fbank_musan.py deleted file mode 100755 index 48905de6f..000000000 --- a/egs/iwslt22_ta/ST/local/compute_fbank_musan.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def is_cut_long(c: MonoCut) -> bool: - return c.duration > 5 - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(30, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = output_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(is_cut_long) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/musan_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/compute_fbank_musan.py b/egs/iwslt22_ta/ST/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/display_manifest_statistics.py b/egs/iwslt22_ta/ST/local/display_manifest_statistics.py deleted file mode 100755 index d3e224905..000000000 --- a/egs/iwslt22_ta/ST/local/display_manifest_statistics.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. -""" - - -from lhotse import load_manifest - - -def main(): - # path = "./data/fbank/cuts_train.jsonl.gz" - path = "./data/fbank/cuts_dev.jsonl.gz" - # path = "./data/fbank/cuts_test.jsonl.gz" - - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -# train - -Cuts count: 1125309 -Total duration (hours): 3403.9 -Speech duration (hours): 3403.9 (100.0%) -*** -Duration statistics (seconds): -mean 10.9 -std 10.1 -min 0.2 -25% 5.2 -50% 7.8 -75% 12.7 -99% 52.0 -99.5% 65.1 -99.9% 99.5 -max 228.9 - - -# test -Cuts count: 5365 -Total duration (hours): 9.6 -Speech duration (hours): 9.6 (100.0%) -*** -Duration statistics (seconds): -mean 6.4 -std 1.5 -min 1.6 -25% 5.3 -50% 6.5 -75% 7.6 -99% 9.5 -99.5% 9.7 -99.9% 10.3 -max 12.4 - -# dev -Cuts count: 5002 -Total duration (hours): 8.5 -Speech duration (hours): 8.5 (100.0%) -*** -Duration statistics (seconds): -mean 6.1 -std 1.7 -min 1.5 -25% 4.8 -50% 6.2 -75% 7.4 -99% 9.5 -99.5% 9.7 -99.9% 10.1 -max 20.3 - -""" diff --git a/egs/iwslt22_ta/ST/local/display_manifest_statistics.py b/egs/iwslt22_ta/ST/local/display_manifest_statistics.py new file mode 120000 index 000000000..e99e43515 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/display_manifest_statistics.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/display_manifest_statistics.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/download_lm.py b/egs/iwslt22_ta/ST/local/download_lm.py deleted file mode 100755 index 94d23afed..000000000 --- a/egs/iwslt22_ta/ST/local/download_lm.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file downloads the following LibriSpeech LM files: - - - 3-gram.pruned.1e-7.arpa.gz - - 4-gram.arpa.gz - - librispeech-vocab.txt - - librispeech-lexicon.txt - -from http://www.openslr.org/resources/11 -and save them in the user provided directory. - -Files are not re-downloaded if they already exist. - -Usage: - ./local/download_lm.py --out-dir ./download/lm -""" - -import argparse -import gzip -import logging -import os -import shutil -from pathlib import Path - -from lhotse.utils import urlretrieve_progress -from tqdm.auto import tqdm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--out-dir", type=str, help="Output directory.") - - args = parser.parse_args() - return args - - -def main(out_dir: str): - url = "http://www.openslr.org/resources/11" - out_dir = Path(out_dir) - - files_to_download = ( - "3-gram.pruned.1e-7.arpa.gz", - "4-gram.arpa.gz", - "librispeech-vocab.txt", - "librispeech-lexicon.txt", - ) - - for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): - filename = out_dir / f - if filename.is_file() is False: - urlretrieve_progress( - f"{url}/{f}", - filename=filename, - desc=f"Downloading {filename}", - ) - else: - logging.info(f"{filename} already exists - skipping") - - if ".gz" in str(filename): - unzipped = Path(os.path.splitext(filename)[0]) - if unzipped.is_file() is False: - with gzip.open(filename, "rb") as f_in: - with open(unzipped, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - else: - logging.info(f"{unzipped} already exist - skipping") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - logging.info(f"out_dir: {args.out_dir}") - - main(out_dir=args.out_dir) diff --git a/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py b/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py deleted file mode 100755 index 566c0743d..000000000 --- a/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input a lexicon.txt and output a new lexicon, -in which each word has a unique pronunciation. - -The way to do this is to keep only the first pronunciation of a word -in lexicon.txt. -""" - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -from icefall.lexicon import read_lexicon, write_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - This file will generate a new file uniq_lexicon.txt - in it. - """, - ) - - return parser.parse_args() - - -def filter_multiple_pronunications( - lexicon: List[Tuple[str, List[str]]] -) -> List[Tuple[str, List[str]]]: - """Remove multiple pronunciations of words from a lexicon. - - If a word has more than one pronunciation in the lexicon, only - the first one is kept, while other pronunciations are removed - from the lexicon. - - Args: - lexicon: - The input lexicon, containing a list of (word, [p1, p2, ..., pn]), - where "p1, p2, ..., pn" are the pronunciations of the "word". - Returns: - Return a new lexicon where each word has a unique pronunciation. - """ - seen = set() - ans = [] - - for word, tokens in lexicon: - if word in seen: - continue - seen.add(word) - ans.append((word, tokens)) - return ans - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - lexicon_filename = lang_dir / "lexicon.txt" - - in_lexicon = read_lexicon(lexicon_filename) - - out_lexicon = filter_multiple_pronunications(in_lexicon) - - write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) - - logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") - logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py b/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/prepare_lang.py b/egs/iwslt22_ta/ST/local/prepare_lang.py deleted file mode 100755 index 1f7120c99..000000000 --- a/egs/iwslt22_ta/ST/local/prepare_lang.py +++ /dev/null @@ -1,414 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - # the next un-allocated state, will be incremented as we go. - next_state = 3 - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/local/prepare_lang.py b/egs/iwslt22_ta/ST/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py b/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py deleted file mode 100755 index 24104581f..000000000 --- a/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -""" - -This script takes as input `lang_dir`, which should contain:: - - - lang_dir/bpe.model, - - lang_dir/words.txt - -and generates the following files in the directory `lang_dir`: - - - lexicon.txt - - lexicon_disambig.txt - - L.pt - - L_disambig.pt - - tokens.txt -""" - -import argparse -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import sentencepiece as spm -import torch -from prepare_lang import ( - Lexicon, - add_disambig_symbols, - add_self_loops, - write_lexicon, - write_mapping, -) - -from icefall.utils import str2bool -import pdb - - -def lexicon_to_fst_no_sil( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format). - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - loop_state = 0 # words enter and leave from here - next_state = 1 # the next un-allocated state, will be incremented as we go - - arcs = [] - - # The blank symbol is defined in local/train_bpe_model.py - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - for word, pieces in lexicon: - assert len(pieces) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - pieces = [token2id[i] for i in pieces] - - for i in range(len(pieces) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, pieces[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last piece of this word - i = len(pieces) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, pieces[i], w, 0]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def generate_lexicon( - model_file: str, words: List[str] -) -> Tuple[Lexicon, Dict[str, int]]: - """Generate a lexicon from a BPE model. - - Args: - model_file: - Path to a sentencepiece model. - words: - A list of strings representing words. - Returns: - Return a tuple with two elements: - - A dict whose keys are words and values are the corresponding - word pieces. - - A dict representing the token symbol, mapping from tokens to IDs. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(model_file)) - - words_pieces: List[List[str]] = sp.encode(words, out_type=str) - - lexicon = [] - for word, pieces in zip(words, words_pieces): - lexicon.append((word, pieces)) - - # The OOV word is - lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) - - token2id: Dict[str, int] = dict() - for i in range(sp.vocab_size()): - token2id[sp.id_to_piece(i)] = i - - return lexicon, token2id - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the bpe.model and words.txt - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - - See "test/test_bpe_lexicon.py" for usage. - """, - ) - - return parser.parse_args() - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - model_file = lang_dir / "bpe.model" - - word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") - - words = word_sym_table.symbols - - excluded = ["", "!SIL", "", "", "#0", "", ""] - for w in excluded: - if w in words: - words.remove(w) - - lexicon, token_sym_table = generate_lexicon(model_file, words) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - next_token_id = max(token_sym_table.values()) + 1 - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in token_sym_table - token_sym_table[disambig] = next_token_id - next_token_id += 1 - - word_sym_table.add("#0") - word_sym_table.add("") - word_sym_table.add("") - - write_mapping(lang_dir / "tokens.txt", token_sym_table) - - write_lexicon(lang_dir / "lexicon.txt", lexicon) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst_no_sil( - lexicon, - token2id=token_sym_table, - word2id=word_sym_table, - ) - - L_disambig = lexicon_to_fst_no_sil( - lexicon_disambig, - token2id=token_sym_table, - word2id=word_sym_table, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py b/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/iwslt22_ta/ST/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/prepare_transcripts.py b/egs/iwslt22_ta/ST/local/prepare_transcripts.py index 46d051c9a..c4e139829 100755 --- a/egs/iwslt22_ta/ST/local/prepare_transcripts.py +++ b/egs/iwslt22_ta/ST/local/prepare_transcripts.py @@ -57,9 +57,8 @@ def main(): with open(langdirs[0] / "transcript_words.txt", 'w') as src, open(langdirs[1] / "transcript_words.txt", 'w') as tgt: for c in cuts: - #breakpoint() src_txt = c.supervisions[0].text - tgt_txt = c.supervisions[0].custom['tgt_text'] + tgt_txt = c.supervisions[0].custom['translated_text']['eng'] src.write(src_txt + '\n') tgt.write(tgt_txt + '\n') diff --git a/egs/iwslt22_ta/ST/local/test_prepare_lang.py b/egs/iwslt22_ta/ST/local/test_prepare_lang.py deleted file mode 100755 index d4cf62bba..000000000 --- a/egs/iwslt22_ta/ST/local/test_prepare_lang.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/local/test_prepare_lang.py b/egs/iwslt22_ta/ST/local/test_prepare_lang.py new file mode 120000 index 000000000..f0f864998 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/test_prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/test_prepare_lang.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/local/train_bpe_model.py b/egs/iwslt22_ta/ST/local/train_bpe_model.py deleted file mode 100755 index bc5812810..000000000 --- a/egs/iwslt22_ta/ST/local/train_bpe_model.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain the training corpus: transcript_words.txt. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/local/train_bpe_model.py b/egs/iwslt22_ta/ST/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/iwslt22_ta/ST/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_finetune_datamodule.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_finetune_datamodule.py deleted file mode 100644 index 13bc882b3..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/asr_finetune_datamodule.py +++ /dev/null @@ -1,422 +0,0 @@ -# Copyright 2022 Amir Hussein - -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class MGB2AsrDataModule: - - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank2"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir /"cuts_musan.jsonl.gz" - ) - - transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "callhome"/"cuts_teltrain_shuf.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome"/ "cuts_devall.jsonl.gz") - - @lru_cache() - def lev_test_cuts(self) -> CutSet: - logging.info("About to get lev test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" / "cuts_levtest.jsonl.gz") - - @lru_cache() - def iraqi_test_cuts(self) -> CutSet: - logging.info("About to get iraqi test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" / "cuts_iraqitest.jsonl.gz") - - @lru_cache() - def gulf_test_cuts(self) -> CutSet: - logging.info("About to get gukf test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_gulftest.jsonl.gz") - - @lru_cache() - def egy_test_cuts(self) -> CutSet: - logging.info("About to get egy test cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egytest.jsonl.gz") - - @lru_cache() - def egy_sup_cuts(self) -> CutSet: - logging.info("About to get egy sup cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egysup.jsonl.gz") - - @lru_cache() - def egy_h5_cuts(self) -> CutSet: - logging.info("About to get egy h5 cuts") - return load_manifest_lazy(self.args.manifest_dir / "callhome" /"cuts_egyh5.jsonl.gz") - \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py deleted file mode 100644 index 5e9428b60..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py +++ /dev/null @@ -1,2085 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from model import Transducer - -from icefall import NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = (logits / temperature).log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search_old.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search_old.py deleted file mode 100644 index ce8b04afd..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/beam_search_old.py +++ /dev/null @@ -1,977 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass -from typing import Dict, List, Optional - -import k2 -import torch -from model import Transducer - -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts - - -def fast_beam_search_one_best( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> List[List[int]]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - Returns: - Return the decoded result. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - ) - - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps - - -def fast_beam_search_nbest_oracle( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, -) -> List[List[int]]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - - Returns: - Return the decoded result. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - hyps = get_texts(best_path) - return hyps - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - return hyp - - -def greedy_search_batch( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - Returns: - Return a list-of-list of token IDs containing the decoded results. - len(ans) equals to encoder_out.size(0). - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = _get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - return ys - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/dataloader_invest.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/dataloader_invest.py deleted file mode 100644 index 96a45b87d..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/dataloader_invest.py +++ /dev/null @@ -1,93 +0,0 @@ -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional -import pdb -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, - SpecAugment, -) - -from lhotse import RecordingSet, SupervisionSet, CutSet -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -def train_dataloaders( - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, -) -> DataLoader: - - transforms = [] - bucketing_sampler = True - logging.info("About to create train dataset") - pdb.set_trace() - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=True, - ) - - if bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=100, - shuffle=True, - num_buckets=30, - drop_last=True, - ) - logging.info("About to create train dataloader") - - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=0, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - -# creating the cut -dir = Path("/alt-arabic/speech/amir/kanari_models/k2/stateless_tranducer5/data/fbank2/callhome") -# sup = SupervisionSet.from_file(dir / 'supervisions.jsonl.gz') -# rec = RecordingSet.from_file(dir / 'recordings.jsonl.gz') -# cuts = CutSet.from_manifests(recordings=rec, supervisions=sup) -cuts = load_manifest_lazy(dir / 'cuts_levtest.jsonl.gz') -print('loaded') - -epoch = 10 -train_dl = train_dataloaders(cuts) -train_dl.sampler.set_epoch(epoch - 1) - -pdb.set_trace() -for batch_idx, batch in enumerate(train_dl): - cur_batch_idx = batch_idx - batch_size = len(batch["supervisions"]["text"]) - print(batch["inputs"]) - diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_st.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py similarity index 100% rename from egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_st.py rename to egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode.py diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_asr.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_asr.py deleted file mode 100755 index 591870029..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_asr.py +++ /dev/null @@ -1,960 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Johns Hopkins (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(3) modified beam search - ./pruned_transducer_stateless5/decode_asr.py \ - --epoch 15 \ - --beam-size 20 \ - --avg 5 \ - --exp-dir ./pruned_transducer_stateless5/exp_asr \ - --max-duration 400 \ - --decoding-method modified_beam_search \ - --max-sym-per-frame 1 \ - --num-encoder-layers 12 \ - --dim-feedforward 1024 \ - --nhead 8 \ - --encoder-dim 256 \ - --decoder-dim 256 \ - --joiner-dim 256 \ - --use-averaged-model true - -""" - - -import argparse -import logging -import math -import pdb -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple -from lhotse.qa import validate_cut -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import IWSLTDialectSTDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, - modified_beam_search_rnnlm_shallow_fusion, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) - parser.add_argument( - "--bpe-tgt-model", - type=str, - default="data/lang_bpe_en_1000/bpe.model", - help="Path to target data BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/ang_bpe_ta_1000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--lang-tgt-dir", - type=Path, - default="data/lang_bpe_en_1000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is fast_beam_search_LG, - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is fast_beam_search_LG, - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", - type=str2bool, - default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens) - - hyps = [] - - if ( - params.decoding_method == "fast_beam_search" - or params.decoding_method == "fast_beam_search_LG" - ): - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - if params.decoding_method == "fast_beam_search": - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} -def remove_short_and_long_utt(c): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.5 or c.duration > 30.0: - #logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) - return False - if c.supervisions == []: - return False - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsamplin - - return True - -# def remove_seg(c): -# if c.supervisions[0].id != 'fla_0102_1_0B_00107': -# return True -# else: -# return False - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - logging.info(f"Decoding {batch_idx}-th batch") - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / - f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / - f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - IWSLTDialectSTDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_rnnlm_shallow_fusion", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - rnn_lm_model = None - rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, - ) - assert params.rnn_lm_avg == 1 - - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() - - if "fast_beam_search" in params.decoding_method: - if "LG" in params.decoding_method: - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - iwslt_ta = IWSLTDialectSTDataModule(args) - - test_cuts = iwslt_ta.test_cuts() - dev_cuts = iwslt_ta.dev_cuts() - - # lev_test_cuts = lev_test_cuts.filter(remove_short_and_long_utt) - # # lev_test_cuts = lev_test_cuts.filter(remove_seg) - # gulf_test_cuts = gulf_test_cuts.filter(remove_short_and_long_utt) - # egy_test_cuts = egy_test_cuts.filter(remove_short_and_long_utt) - # egy_h5_cuts = egy_sup_cuts.filter(remove_short_and_long_utt) - # egy_sup_cuts = egy_h5_cuts.filter(remove_short_and_long_utt) - - test_dl = iwslt_ta.test_dataloaders(test_cuts) - dev_dl = iwslt_ta.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py deleted file mode 100755 index e522943c0..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 - - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py new file mode 120000 index 000000000..d59ef95f7 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py deleted file mode 100644 index b6d94aaf1..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from scaling import ScaledConv1d, ScaledEmbedding - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = ScaledEmbedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = ScaledConv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim, - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py new file mode 120000 index 000000000..722e1c894 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py new file mode 120000 index 000000000..f58253127 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py deleted file mode 100755 index 513388113..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py +++ /dev/null @@ -1,243 +0,0 @@ -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - -# python pruned_transducer_stateless5/export.py --exp-dir pruned_transducer_stateless5/exp_streaming --streaming-model 1 --causal-convolution 1 --jit 1 --epoch 10 --avg 5 --bpe-model data/lang_bpe_2000/bpe.model -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=10, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp_streaming", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--streaming-model", - type=str2bool, - default=True, - help="""Whether to export a streaming model, if the models in exp-dir - are streaming model, this should be True. - """, - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - if params.streaming_model: - assert params.causal_convolution - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py new file mode 120000 index 000000000..cf6a89299 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/export.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py deleted file mode 100644 index d5f4a7bd6..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear -from icefall.utils import is_jit_tracing - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) - self.output_linear = ScaledLinear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ -# assert encoder_out.ndim == decoder_out.ndim == 4 -# assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - -# if project_input: -# logit = self.encoder_proj(encoder_out) + self.decoder_proj( -# decoder_out -# ) - if not is_jit_tracing(): - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py new file mode 120000 index 000000000..9052f3cbb --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py deleted file mode 100644 index 272d06c37..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py new file mode 120000 index 000000000..a99e74334 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py deleted file mode 100644 index 432bf8220..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Union - -import torch -from torch.optim import Optimizer - - -class Eve(Optimizer): - r""" - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) - if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) - if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) - p.mul_(1 - (weight_decay * is_above_target_rms)) - p.addcdiv_(exp_avg, denom, value=-step_size) - - return loss - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) - - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - print( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) - - E.g. suggest initial-lr = 0.003 (passed to optimizer). - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - - def get_lr(self): - factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 - ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 - ) - return [x * factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = Eve(m.parameters(), lr=0.003) - - scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - print("last lr = ", scheduler.get_last_lr()) - print("state dict = ", scheduler.state_dict()) - - -if __name__ == "__main__": - _test_eden() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py new file mode 120000 index 000000000..0a2f285aa --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py deleted file mode 100644 index 5ee4bab98..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py +++ /dev/null @@ -1,719 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -from itertools import repeat -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from torch import Tensor - - -def _ntuple(n): - def parse(x): - if isinstance(x, collections.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs - - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() - ) ** -0.5 - return x * scales - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * self.weight_scale.exp() - bias = self.bias * self.bias_scale.exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - initial_speed: this affects how fast the parameter will - learn near the start of training; you can set it to a - value less than one if you suspect that a module - is contributing to instability near the start of training. - Nnote: regardless of the use of this option, it's best to - use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - """ - - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - if self.bias is None or self.bias_scale is None: - return None - - return self.bias * self.bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) - - -class ScaledConv1d(nn.Conv1d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledConv1d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - return bias * bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - (0,), - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -class ScaledConv2d(nn.Conv2d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs - ): - super(ScaledConv2d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - # see https://github.com/pytorch/pytorch/issues/24135 - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - return bias * bias_scale.exp() - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - (0, 0), - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting(): - return x - - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - x = x.detach() - s = torch.sigmoid(x - 1.0) - y = x * s - ctx.save_for_backward(s, y) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -class ScaledEmbedding(nn.Module): - r"""This is a modified version of nn.Embedding that introduces a learnable scale - on the parameters. Note: due to how we initialize it, it's best used with - schedulers like Noam that have a warmup period. - - It is a simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - initial_speed (float, optional): This affects how fast the parameter will - learn near the start of training; you can set it to a value less than - one if you suspect that a module is contributing to instability near - the start of training. Nnote: regardless of the use of this option, - it's best to use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - - """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - initial_speed: float = 1.0, - ) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" - elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters(initial_speed) - - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.1 / initial_speed - nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - scale = self.scale.exp() - if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) - else: - return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) - - def extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" - if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" - if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" - if self.sparse is not False: - s += ", sparse=True" - return s.format(**self.__dict__) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 0.5 - x.requires_grad = True - m = DoubleSwish() - torch.autograd.gradcheck(m, x) - - -if __name__ == "__main__": - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py new file mode 120000 index 000000000..ff7bfeda9 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py deleted file mode 100644 index 06a81656c..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file provides functions to convert `ScaledLinear`, `ScaledConv1d`, -`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts: -`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`. -The scaled version are required only in the training time. It simplifies our -life by converting them to their non-scaled version during inference. -""" - -import copy -import re -from typing import List - -import torch -import torch.nn as nn -from lstmp import LSTMP -from scaling import ( - ActivationBalancer, - BasicNorm, - ScaledConv1d, - ScaledConv2d, - ScaledEmbedding, - ScaledLinear, - ScaledLSTM, -) - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, - keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: - """Convert an instance of ScaledLinear to nn.Linear. - Args: - scaled_linear: - The layer to be converted. - Returns: - Return a linear layer. It satisfies: - scaled_linear(x) == linear(x) - for any given input tensor `x`. - """ - assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear) - - weight = scaled_linear.get_weight() - bias = scaled_linear.get_bias() - has_bias = bias is not None - - linear = torch.nn.Linear( - in_features=scaled_linear.in_features, - out_features=scaled_linear.out_features, - bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not have - # this argument. Comment out for now, we will - # see if it will raise error for versions - # after v1.9.0 - ) - linear.weight.data.copy_(weight) - - if has_bias: - linear.bias.data.copy_(bias) - else: - linear.bias.data.zero_() - - return linear - - -def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d: - """Convert an instance of ScaledConv1d to nn.Conv1d. - Args: - scaled_conv1d: - The layer to be converted. - Returns: - Return an instance of nn.Conv1d that has the same `forward()` behavior - of the given `scaled_conv1d`. - """ - assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d) - - weight = scaled_conv1d.get_weight() - bias = scaled_conv1d.get_bias() - has_bias = bias is not None - - conv1d = nn.Conv1d( - in_channels=scaled_conv1d.in_channels, - out_channels=scaled_conv1d.out_channels, - kernel_size=scaled_conv1d.kernel_size, - stride=scaled_conv1d.stride, - padding=scaled_conv1d.padding, - dilation=scaled_conv1d.dilation, - groups=scaled_conv1d.groups, - bias=scaled_conv1d.bias is not None, - padding_mode=scaled_conv1d.padding_mode, - ) - - conv1d.weight.data.copy_(weight) - if has_bias: - conv1d.bias.data.copy_(bias) - - return conv1d - - -def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d: - """Convert an instance of ScaledConv2d to nn.Conv2d. - Args: - scaled_conv2d: - The layer to be converted. - Returns: - Return an instance of nn.Conv2d that has the same `forward()` behavior - of the given `scaled_conv2d`. - """ - assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d) - - weight = scaled_conv2d.get_weight() - bias = scaled_conv2d.get_bias() - has_bias = bias is not None - - conv2d = nn.Conv2d( - in_channels=scaled_conv2d.in_channels, - out_channels=scaled_conv2d.out_channels, - kernel_size=scaled_conv2d.kernel_size, - stride=scaled_conv2d.stride, - padding=scaled_conv2d.padding, - dilation=scaled_conv2d.dilation, - groups=scaled_conv2d.groups, - bias=scaled_conv2d.bias is not None, - padding_mode=scaled_conv2d.padding_mode, - ) - - conv2d.weight.data.copy_(weight) - if has_bias: - conv2d.bias.data.copy_(bias) - - return conv2d - - -def scaled_embedding_to_embedding( - scaled_embedding: ScaledEmbedding, -) -> nn.Embedding: - """Convert an instance of ScaledEmbedding to nn.Embedding. - Args: - scaled_embedding: - The layer to be converted. - Returns: - Return an instance of nn.Embedding that has the same `forward()` behavior - of the given `scaled_embedding`. - """ - assert isinstance(scaled_embedding, ScaledEmbedding), type( - scaled_embedding) - embedding = nn.Embedding( - num_embeddings=scaled_embedding.num_embeddings, - embedding_dim=scaled_embedding.embedding_dim, - padding_idx=scaled_embedding.padding_idx, - scale_grad_by_freq=scaled_embedding.scale_grad_by_freq, - sparse=scaled_embedding.sparse, - ) - weight = scaled_embedding.weight - scale = scaled_embedding.scale - - embedding.weight.data.copy_(weight * scale.exp()) - - return embedding - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: - """Convert an instance of ScaledLSTM to nn.LSTM. - Args: - scaled_lstm: - The layer to be converted. - Returns: - Return an instance of nn.LSTM that has the same `forward()` behavior - of the given `scaled_lstm`. - """ - assert isinstance(scaled_lstm, ScaledLSTM), type(scaled_lstm) - lstm = nn.LSTM( - input_size=scaled_lstm.input_size, - hidden_size=scaled_lstm.hidden_size, - num_layers=scaled_lstm.num_layers, - bias=scaled_lstm.bias, - batch_first=scaled_lstm.batch_first, - dropout=scaled_lstm.dropout, - bidirectional=scaled_lstm.bidirectional, - proj_size=scaled_lstm.proj_size, - ) - - assert lstm._flat_weights_names == scaled_lstm._flat_weights_names - for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * \ - scaled_lstm._scales[idx].exp() - lstm._flat_weights[idx].data.copy_(scaled_weight) - - return lstm - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_onnx: bool = False, -): - """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` - in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, - and `nn.Conv2d`. - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_onnx: - If True, we are going to export the model to ONNX. In this case, - we will convert nn.LSTM with proj_size to LSTMP. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - excluded_patterns = r"(self|src)_attn\.(in|out)_proj" - p = re.compile(excluded_patterns) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, ScaledLinear): - if p.search(name) is not None: - continue - d[name] = scaled_linear_to_linear(m) - elif isinstance(m, ScaledConv1d): - d[name] = scaled_conv1d_to_conv1d(m) - elif isinstance(m, ScaledConv2d): - d[name] = scaled_conv2d_to_conv2d(m) - elif isinstance(m, ScaledEmbedding): - d[name] = scaled_embedding_to_embedding(m) - elif isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, ScaledLSTM): - if is_onnx: - d[name] = LSTMP(scaled_lstm_to_lstm(m)) - # See - # https://github.com/pytorch/pytorch/issues/47887 - # d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m))) - else: - d[name] = scaled_lstm_to_lstm(m) - elif isinstance(m, ActivationBalancer): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..e58473a04 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py deleted file mode 100644 index e6e0fb1c8..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py new file mode 120000 index 000000000..2f76638ac --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py deleted file mode 100755 index 0a61c9493..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py +++ /dev/null @@ -1,608 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-size 8 \ - --left-context 32 \ - --right-context 0 \ - --exp-dir ./pruned_transducer_stateless/exp \ - --decoding_method greedy_search \ - --num-decode-streams 1000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import ( - AttributeDict, - str2bool, - setup_logger, - store_transcripts, - write_error_stats, -) -import pdb - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num-active-paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--right-context", - type=int, - default=0, - help="right context can be seen during decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames( - params.decode_chunk_size * params.subsampling_factor - ) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # if T is less than 7 there will be an error in time reduction layer, - # because we subsample features with ((x_len - 1) // 2 - 1) // 2 - # we plus 2 here because we will cut off one frame on each size of - # encoder_embed output as they see invalid paddings. so we need extra 2 - # frames. - tail_length = 7 + (2 + params.right_context) * params.subsampling_factor - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = [ - torch.stack([x[0] for x in states], dim=2), - torch.stack([x[1] for x in states], dim=2), - ] - - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - left_context=params.left_context, - right_context=params.right_context, - processed_lens=processed_lens, - ) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, - streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}") - - states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = [states[0][i], states[1][i]] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device) - for num, cut_ in enumerate(cuts): - # each utterance has a DecodeStream. - for cut in cut_["supervisions"]["cut"]: - # pdb.set_trace() - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - decode_stream.set_features(fbank(samples.to(device))) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - # pdb.set_trace() - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode( - decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode( - decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}") - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / - f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - # sort results so we can easily compare the difference between two - # recognition results - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / - f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - params.suffix += f"-right-context-{params.right_context}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - params.causal_convolution = True - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - MGB2 = MGB2AsrDataModule(args) - - test_cuts = MGB2.test_cuts() - dev_cuts = MGB2.dev_cuts() - - test_dl = MGB2.test_dataloaders(test_cuts) - dev_dl = MGB2.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - cuts=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py new file mode 120000 index 000000000..f29284163 --- /dev/null +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_st.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py similarity index 98% rename from egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_st.py rename to egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py index 54a9dc51e..5801fde9b 100755 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_st.py +++ b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train.py @@ -17,12 +17,12 @@ """ Usage: -# ./pruned_transducer_stateless5/train_st.py \ +# ./pruned_transducer_stateless5/train.py \ # --world-size 4 \ # --num-epochs 20 \ # --start-epoch 1 \ # --exp-dir pruned_transducer_stateless5/exp_st \ -# --max-duration 150 \ +# --max-duration 300 \ # --bucketing-sampler 1\ # --num-buckets 50 @@ -225,12 +225,6 @@ def get_parser(): """, ) - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) parser.add_argument( "--bpe-tgt-model", type=str, @@ -656,7 +650,7 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) #pdb.set_trace() texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"] + tgt_texts = batch["supervisions"]["tgt_text"]['eng'] y = sp.encode(texts, out_type=int) y_tgt = sp_tgt.encode(tgt_texts, out_type=int) y = k2.RaggedTensor(y).to(device) @@ -774,7 +768,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, sp_tgt: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -834,7 +827,6 @@ def train_one_epoch( loss, loss_info, inf_flag = compute_loss( params=params, model=model, - sp=sp, sp_tgt=sp_tgt, batch=batch, is_training=True, @@ -927,7 +919,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - sp=sp, sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, @@ -1007,9 +998,7 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - sp = spm.SentencePieceProcessor() sp_tgt = spm.SentencePieceProcessor() - sp.load(params.bpe_model) sp_tgt.load(params.bpe_tgt_model) # pdb.set_trace() # is defined in local/train_bpe_model.py @@ -1082,7 +1071,7 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 0.2 or c.duration > 30.0: + if c.duration < 0.1 or c.duration > 30.0: #logging.warning( # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" #) @@ -1139,7 +1128,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - sp=sp, sp_tgt=sp_tgt, params=params, warmup=0.0 if params.start_epoch == 1 else 1.0, @@ -1167,7 +1155,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - sp=sp, sp_tgt=sp_tgt, train_dl=train_dl, valid_dl=valid_dl, diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_analysis.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_analysis.py deleted file mode 100755 index 1e7acfe97..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_analysis.py +++ /dev/null @@ -1,1202 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless5/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless5/exp \ - --full-libri 1 \ - --max-duration 550 - -""" - -# xxx -import argparse -import copy -import logging -import pdb -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import nvidia_smi -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import MGB2AsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=2048, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_2000/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 80000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, - reduction="none", -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - inf_flag = False - if not torch.all(is_finite): - inf_flag = True - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) - - # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - # info["utterances"] = feature.size(0) - # # averaged input duration in frames over utterances - # info["utt_duration"] = feature_lens.sum().item() - # # averaged padding proportion over utterances - # info["utt_pad_proportion"] = ( - # ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - # ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info, inf_flag - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), - ) - # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if not inf_flag: - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - memory_debugging() - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - else: - logging.warning( - f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." - ) - continue - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def memory_debugging(): - # memory nvidia debugging - nvidia_smi.nvmlInit() - - deviceCount = nvidia_smi.nvmlDeviceGetCount() - for i in range(deviceCount): - handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i) - info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) - logging.info( - "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format( - i, - nvidia_smi.nvmlDeviceGetName(handle), - 100 * info.free / info.total, - info.total, - info.free, - info.used, - ) - ) - - nvidia_smi.nvmlShutdown() - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - pdb.set_trace() - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - MGB2 = MGB2AsrDataModule(args) - train_cuts = MGB2.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 30 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 0.5 <= c.duration <= 30.0 - - def remove_short_and_long_text(c: Cut): - # Keep only text with charachters between 20 and 450 - - return 20 <= len(c.supervisions[0].text) <= 450 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.filter(remove_short_and_long_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = MGB2.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = MGB2.dev_cuts() - valid_dl = MGB2.test_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - - loss, _, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - # clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - - -def main(): - parser = get_parser() - MGB2AsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_asr.py b/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_asr.py deleted file mode 100755 index 68f01dfac..000000000 --- a/egs/iwslt22_ta/ST/pruned_transducer_stateless5/train_asr.py +++ /dev/null @@ -1,1301 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -./pruned_transducer_stateless5/train_asr.py \ - --world-size 4 \ - --num-epochs 20 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless5/exp_asr \ - --max-duration 150 \ - --bucketing-sampler 1\ - --num-buckets 50 - -""" - -# xxx -import argparse -import copy -import logging -import pdb -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import nvidia_smi -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import IWSLTDialectSTDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of conformer encoder layers..", - ) - - parser.add_argument( - "--dim-feedforward", - type=int, - default=1024, - help="Feedforward dimension of the conformer encoder layer.", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads in the conformer encoder layer.", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=256, - help="Attention dimension in the conformer encoder layer.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=256, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=256, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) - parser.add_argument( - "--bpe-tgt-model", - type=str, - default="data/lang_bpe_en_1000/bpe.model", - help="Path to target data BPE model", - ) - parser.add_argument( - "--initial-lr", - type=float, - default=0.001, - help="The initial learning rate. This value should not need " - "to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=16000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value used to penalize symbol delay, - to encourage streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details.""", - ) - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - # parameters for Noam - "model_warm_step": 15000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - dynamic_chunk_training=params.dynamic_chunk_training, - short_chunk_size=params.short_chunk_size, - num_left_chunks=params.num_left_chunks, - causal=params.causal_convolution, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, - reduction="none", -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - #pdb.set_trace() - texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"] - y = sp.encode(texts, out_type=int) - y_tgt = sp_tgt.encode(tgt_texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - y_tgt = k2.RaggedTensor(y_tgt).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - inf_flag = False - if not torch.all(is_finite): - inf_flag = True - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) - - # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - # info["utterances"] = feature.size(0) - # # averaged input duration in frames over utterances - # info["utt_duration"] = feature_lens.sum().item() - # # averaged padding proportion over utterances - # info["utt_pad_proportion"] = ( - # ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - # ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info, inf_flag - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - warmup=( - params.batch_idx_train / params.model_warm_step - ), - ) - # summary stats - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if not inf_flag: - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea - #memory_debugging() - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - else: - logging.warning( - f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." - ) - continue - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def memory_debugging(): - # memory nvidia debugging - nvidia_smi.nvmlInit() - - deviceCount = nvidia_smi.nvmlDeviceGetCount() - for i in range(deviceCount): - handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i) - info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) - logging.info( - "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format( - i, - nvidia_smi.nvmlDeviceGetName(handle), - 100 * info.free / info.total, - info.total, - info.free, - info.used, - ) - ) - - nvidia_smi.nvmlShutdown() - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp_tgt = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - sp_tgt.load(params.bpe_tgt_model) - # pdb.set_trace() - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - iwslt_ta = IWSLTDialectSTDataModule(args) - train_cuts = iwslt_ta.train_cuts() - # def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 30 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - # return 1 <= c.duration <= 28.0 - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - #logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) - return False - if c.supervisions == []: - return False - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - return False - - return True - - def remove_short_and_long_text(c: Cut): - # Keep only text with charachters between 20 and 400 - - return 3 <= len(c.supervisions[0].text) <= 400 - #logging.info(f"Total duration before filtering {train_cuts.describe()}") - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.filter(remove_short_and_long_text) - #logging.info(f"Total duration after filtering {train_cuts.describe()}") - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = iwslt_ta.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = iwslt_ta.dev_cuts() - valid_dl = iwslt_ta.test_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - sp_tgt=sp_tgt, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - sp_tgt=sp_tgt, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - #y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - - loss, _, _ = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - # clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - - -def main(): - parser = get_parser() - IWSLTDialectSTDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/shared b/egs/iwslt22_ta/ST/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/iwslt22_ta/ST/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py b/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py index dd1ba4a8d..b822b7c68 100644 --- a/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py +++ b/egs/iwslt22_ta/ST/zipformer/asr_datamodule.py @@ -29,9 +29,9 @@ from lhotse.dataset import ( CutConcatenate, CutMix, DynamicBucketingSampler, - K2Speech2textTranslationDataset, + K2Speech2TextTranslationDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -206,9 +206,7 @@ class IWSLTDialectSTDataModule: ) transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -256,7 +254,7 @@ class IWSLTDialectSTDataModule: logging.info("Disable SpecAugment") logging.info("About to create train dataset") - train = K2Speech2textTranslationDataset( + train = K2Speech2TextTranslationDataset( cut_transforms=transforms, input_transforms=input_transforms, return_cuts=self.args.return_cuts, @@ -273,7 +271,7 @@ class IWSLTDialectSTDataModule: # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. - train = K2Speech2textTranslationDataset( + train = K2Speech2TextTranslationDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) @@ -292,8 +290,8 @@ class IWSLTDialectSTDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -330,14 +328,14 @@ class IWSLTDialectSTDataModule: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - validate = K2Speech2textTranslationDataset( + validate = K2Speech2TextTranslationDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: - validate = K2Speech2textTranslationDataset( + validate = K2Speech2TextTranslationDataset( cut_transforms=transforms, return_cuts=self.args.return_cuts, ) @@ -359,7 +357,7 @@ class IWSLTDialectSTDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") - test = K2Speech2textTranslationDataset( + test = K2Speech2TextTranslationDataset( input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats diff --git a/egs/iwslt22_ta/ST/zipformer/decode_st.py b/egs/iwslt22_ta/ST/zipformer/decode.py similarity index 94% rename from egs/iwslt22_ta/ST/zipformer/decode_st.py rename to egs/iwslt22_ta/ST/zipformer/decode.py index e896fb587..265491f0c 100755 --- a/egs/iwslt22_ta/ST/zipformer/decode_st.py +++ b/egs/iwslt22_ta/ST/zipformer/decode.py @@ -19,77 +19,34 @@ Usage: (1) greedy search ./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ + --epoch 20 \ + --avg 13 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ - --decoding-method greedy_search + --decoding-method greedy_search \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --context-size 2 \ + --use-averaged-model true \ + --use-hat-decode false -(2) beam search (not recommended) +(2) modified beam search ./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ + --epoch 20 \ + --avg 13 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 + --beam-size 20 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --context-size 2 \ + --use-averaged-model true \ + --use-hat-decode false """ @@ -230,6 +187,12 @@ def get_parser(): `--lang-dir`, which should contain `LG.pt`. """, ) + parser.add_argument( + "--use-hat-decode", + type=str2bool, + default=False, + help="If True, use HAT loss.", + ) parser.add_argument( "--beam-size", @@ -462,6 +425,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + use_hat=params.use_hat_decode ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -555,7 +519,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - texts_tgt = batch["supervisions"]["tgt_text"] + texts_tgt = batch["supervisions"]["tgt_text"]['eng'] hyps_dict = decode_one_batch( params=params, diff --git a/egs/iwslt22_ta/ST/zipformer/decode_asr.py b/egs/iwslt22_ta/ST/zipformer/decode_asr.py deleted file mode 100755 index 6f5ec49e6..000000000 --- a/egs/iwslt22_ta/ST/zipformer/decode_asr.py +++ /dev/null @@ -1,852 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Johns Hopkins University (Author: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import IWSLTDialectSTDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train_asr import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) - parser.add_argument( - "--bpe-tgt-model", - type=str, - default="data/lang_bpe_en_1000/bpe.model", - help="Path to target data BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/ang_bpe_ta_1000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--lang-tgt-dir", - type=Path, - default="data/lang_bpe_en_1000", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.causal: - # this seems to cause insertions at the end of the utterance if used with zipformer. - pad_len = 30 - feature_lens += pad_len - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, pad_len), - value=LOG_EPS, - ) - - x, x_lens = model.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = model.encoder( - x, x_lens, src_key_padding_mask - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - use_hat=True, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - IWSLTDialectSTDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - # use predefined parameters that were used during the training - # params.num_encoder_layers = "2,2,2,2,2,2" - # params.feedforward_dim = "256,512,768,1024,768,512" - # params.encoder_dim = "128,256,256,512,256,256" - # params.encoder_unmasked_dim = "64,128,128,256,128,128" - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if params.causal: - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - iwslt_ta = IWSLTDialectSTDataModule(args) - - test_cuts = iwslt_ta.test_cuts() - dev_cuts = iwslt_ta.dev_cuts() - test_dl = iwslt_ta.test_dataloaders(test_cuts) - dev_dl = iwslt_ta.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_all_dl = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_all_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/zipformer/decoder.py b/egs/iwslt22_ta/ST/zipformer/decoder.py index 45432d570..0ca06233a 100644 --- a/egs/iwslt22_ta/ST/zipformer/decoder.py +++ b/egs/iwslt22_ta/ST/zipformer/decoder.py @@ -62,10 +62,15 @@ class Decoder(nn.Module): ) # the balancers are to avoid any drift in the magnitude of the # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) self.blank_id = blank_id @@ -82,10 +87,15 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -108,9 +118,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/iwslt22_ta/ST/zipformer/encoder_interface.py b/egs/iwslt22_ta/ST/zipformer/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/iwslt22_ta/ST/zipformer/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/iwslt22_ta/ST/zipformer/encoder_interface.py b/egs/iwslt22_ta/ST/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/joiner.py b/egs/iwslt22_ta/ST/zipformer/joiner.py deleted file mode 100644 index f03cc930e..000000000 --- a/egs/iwslt22_ta/ST/zipformer/joiner.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/iwslt22_ta/ST/zipformer/joiner.py b/egs/iwslt22_ta/ST/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/model.py b/egs/iwslt22_ta/ST/zipformer/model.py index 44d1dca59..05bed9364 100644 --- a/egs/iwslt22_ta/ST/zipformer/model.py +++ b/egs/iwslt22_ta/ST/zipformer/model.py @@ -39,6 +39,7 @@ class Transducer(nn.Module): decoder_dim: int, joiner_dim: int, vocab_size: int, + use_hat: bool = False, ): """ Args: @@ -68,6 +69,7 @@ class Transducer(nn.Module): self.encoder = encoder self.decoder = decoder self.joiner = joiner + self.use_hat = use_hat self.simple_am_proj = ScaledLinear( encoder_dim, @@ -213,7 +215,7 @@ class Transducer(nn.Module): termination_symbol=blank_id, boundary=boundary, reduction="sum", - use_hat_loss=True, + use_hat_loss=self.use_hat, ) return (simple_loss, pruned_loss) diff --git a/egs/iwslt22_ta/ST/zipformer/optim.py b/egs/iwslt22_ta/ST/zipformer/optim.py deleted file mode 100644 index abfb2092c..000000000 --- a/egs/iwslt22_ta/ST/zipformer/optim.py +++ /dev/null @@ -1,1173 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accept model.parameters() or groups of parameters(), - this optimizer could accept model.named_parameters() or groups of named_parameters(). - See comments of function _get_names_of_parameters for its 4 possible cases. - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): - - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True - param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. - # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] - del cur_group["named_params"] - cur_group["params"] = p_list - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - - with self.batched_params(group["params"], group_params_names) as batches: - - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter which dominates tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter dominating tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq={(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/iwslt22_ta/ST/zipformer/optim.py b/egs/iwslt22_ta/ST/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/profile.py b/egs/iwslt22_ta/ST/zipformer/profile.py deleted file mode 100755 index b460b5338..000000000 --- a/egs/iwslt22_ta/ST/zipformer/profile.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: ./zipformer/profile.py -""" - -import argparse -import logging -import sentencepiece as spm -import torch - -from typing import Tuple -from torch import Tensor, nn - -from icefall.utils import make_pad_mask -from icefall.profiler import get_model_profile -from scaling import BiasNorm -from train import ( - get_encoder_embed, - get_encoder_model, - get_joiner_model, - add_model_arguments, - get_params, -) -from zipformer import BypassModule - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - add_model_arguments(parser) - - return parser - - -def _bias_norm_flops_compute(module, input, output): - assert len(input) == 1, len(input) - # estimate as layer_norm, see icefall/profiler.py - flops = input[0].numel() * 5 - module.__flops__ += int(flops) - - -def _swoosh_module_flops_compute(module, input, output): - # For SwooshL and SwooshR modules - assert len(input) == 1, len(input) - # estimate as swish/silu, see icefall/profiler.py - flops = input[0].numel() - module.__flops__ += int(flops) - - -def _bypass_module_flops_compute(module, input, output): - # For Bypass module - assert len(input) == 2, len(input) - flops = input[0].numel() * 2 - module.__flops__ += int(flops) - - -MODULE_HOOK_MAPPING = { - BiasNorm: _bias_norm_flops_compute, - BypassModule: _bypass_module_flops_compute, -} - - -class Model(nn.Module): - """A Wrapper for encoder, encoder_embed, and encoder_proj""" - - def __init__( - self, - encoder: nn.Module, - encoder_embed: nn.Module, - encoder_proj: nn.Module, - ) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: - x, x_lens = self.encoder_embed(feature, feature_lens) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) - - encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - logits = self.encoder_proj(encoder_out) - - return logits, encoder_out_lens - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - - # We only profile the encoder part - model = Model( - encoder=get_encoder_model(params), - encoder_embed=get_encoder_embed(params), - encoder_proj=get_joiner_model(params).encoder_proj, - ) - model.eval() - model.to(device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # for 30-second input - B, T, D = 1, 3000, 80 - feature = torch.ones(B, T, D, dtype=torch.float32).to(device) - feature_lens = torch.full((B,), T, dtype=torch.int64).to(device) - - flops, params = get_model_profile( - model=model, - args=(feature, feature_lens), - module_hoop_mapping=MODULE_HOOK_MAPPING, - ) - logging.info(f"For the encoder part, params: {params}, flops: {flops}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/iwslt22_ta/ST/zipformer/profile.py b/egs/iwslt22_ta/ST/zipformer/profile.py new file mode 120000 index 000000000..c93adbd14 --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/scaling.py b/egs/iwslt22_ta/ST/zipformer/scaling.py deleted file mode 100644 index 908b60938..000000000 --- a/egs/iwslt22_ta/ST/zipformer/scaling.py +++ /dev/null @@ -1,1797 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional, Tuple, Union -import logging -import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random -import torch -import math -import torch.nn as nn -from torch import Tensor - - -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise lienar - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. - """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specifiy the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - def __init__(self, - *args, - default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' - - def __float__(self): - batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting(): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) - else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) - else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = (x_abs < min_abs) - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = (x > self.cutoff) - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x ** 2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return (x - bias) * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None - - -class BiasNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. - """ - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False - ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) - return x * scales - - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) - - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) - - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - return ans - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) - - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) - - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = (m_loss + r_loss) - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, - num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, - x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = (y * (1 - s) + s) - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = (d * ((ceil - floor) / 255.0) + floor) - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - ans, = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient = torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) - - -class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ - if torch.jit.is_scripting(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter('bias', l.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, - x: Tensor): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = ((1.2-(-0.043637))/255.0) - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshL() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshR() - - tol = (1.0 / 255.0) - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) - with torch.no_grad(): - m2.weight[:] = m1[2].weight - if bias: - m2.bias[:] = m1[2].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) - if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() - # the SwooshL() implementation has a noisy gradient due to 1-byte - # storage of it. - assert isclose(x1.grad, x2.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_piecewise_linear() - _test_softmax() - _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_double_swish_deriv() - _test_swooshr_deriv() - _test_swooshl_deriv() - _test_activation_dropout_and_linear() diff --git a/egs/iwslt22_ta/ST/zipformer/scaling.py b/egs/iwslt22_ta/ST/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/scaling_converter.py b/egs/iwslt22_ta/ST/zipformer/scaling_converter.py deleted file mode 100644 index 683a03461..000000000 --- a/egs/iwslt22_ta/ST/zipformer/scaling_converter.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List, Tuple - -import torch -import torch.nn as nn -from scaling import Balancer, Dropout3, ScaleGrad, Whiten - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/iwslt22_ta/ST/zipformer/scaling_converter.py b/egs/iwslt22_ta/ST/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/subsampling.py b/egs/iwslt22_ta/ST/zipformer/subsampling.py deleted file mode 100644 index 47403f13c..000000000 --- a/egs/iwslt22_ta/ST/zipformer/subsampling.py +++ /dev/null @@ -1,407 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple -import warnings - -import torch -from torch import Tensor, nn -from scaling import ( - Balancer, - BiasNorm, - Dropout3, - FloatLike, - Optional, - ScaledConv2d, - ScaleGrad, - ScheduledFloat, - SwooshL, - SwooshR, - Whiten, -) - - -class ConvNeXt(nn.Module): - """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf - """ - - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, - ): - super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=self.padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 - ) - - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - initial_scale=0.01, - ) - - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None - ) -> Tensor: - """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. - """ - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - if layer_skip_mask is not None: - x = x * layer_skip_mask - - x = bypass + x - x = self.out_balancer(x) - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) - - return x - - def streaming_forward( - self, - x: Tensor, - cached_left_pad: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) - - Returns: - - The returned value has the same shape as x. - - Updated cached_left_pad. - """ - padding = self.padding - - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] - - bypass = x[:, :, :T, :] - - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - - x = bypass + x - return x, cached_left_pad - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite - """ - assert in_channels >= 7 - super().__init__() - - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), - SwooshR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), - SwooshR(), - ) - - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - self.out_width = (((in_channels - 1) // 2) - 1) // 2 - self.layer3_channels = layer3_channels - - self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), - prob=(0.025, 0.25), - grad_scale=0.02, - ) - - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - - output lengths, of shape (batch_size,) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - if torch.jit.is_scripting(): - x_lens = (x_lens - 7) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() - - return x, x_lens - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - cached_left_pad: Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - - output lengths, of shape (batch_size,) - - updated cache - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - - # T' = (T-7)//2 - x = self.conv(x) - - # T' = (T-7)//2-3 - x, cached_left_pad = self.convnext.streaming_forward( - x, cached_left_pad=cached_left_pad - ) - - # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, T', out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, T', odim) - x = self.out_norm(x) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert self.convnext.padding[0] == 3 - # The ConvNeXt module needs 3 frames of right padding after subsampling - x_lens = (x_lens - 7) // 2 - 3 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # The ConvNeXt module needs 3 frames of right padding after subsampling - assert self.convnext.padding[0] == 3 - x_lens = (x_lens - 7) // 2 - 3 - - assert x.size(1) == x_lens.max().item() - - return x, x_lens, cached_left_pad - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> Tensor: - """Get initial states for Conv2dSubsampling module. - It is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - """ - left_pad = self.convnext.padding[0] - freq = self.out_width - channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) - - return cached_embed_left_pad diff --git a/egs/iwslt22_ta/ST/zipformer/subsampling.py b/egs/iwslt22_ta/ST/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/iwslt22_ta/ST/zipformer/train.py b/egs/iwslt22_ta/ST/zipformer/train.py index 9788220c9..4b661d93e 100755 --- a/egs/iwslt22_ta/ST/zipformer/train.py +++ b/egs/iwslt22_ta/ST/zipformer/train.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) +# Copyright 2021-2023 Johns Hopkins University (authors: Amir Hussein) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -23,27 +19,58 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 -# For streaming model training: +# medium model 42.5M +# # For streaming model training: ./zipformer/train.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 20 \ --start-epoch 1 \ --use-fp16 1 \ - --exp-dir zipformer/exp \ + --exp-dir zipformer/exp-st-medium2 \ --causal 1 \ - --full-libri 1 \ - --max-duration 1000 + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --max-duration 300 \ + --prune-range 10 \ + --use-hat False +# # For offline model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-st-medium2 \ + --causal 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --max-duration 300 \ + --prune-range 10 \ + --use-hat False + +# medium model 42.5M + ./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-st-medium-nohat800s-warmstep8k_baselr05_lrbatch5k_lrepoch6 \ + --causal 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,384,512,384,256 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --max-duration 800 \ + --prune-range 10 \ + --warm-step 8000 \ + --lr-epochs 6 \ + --base-lr 0.055 \ + --use-hat False """ @@ -61,7 +88,7 @@ import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import IWSLTDialectSTDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -202,14 +229,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--decoder-dim", type=int, - default=512, + default=256, help="Embedding dimension in the decoder model.", ) parser.add_argument( "--joiner-dim", type=int, - default=512, + default=256, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. @@ -222,6 +249,12 @@ def add_model_arguments(parser: argparse.ArgumentParser): default=False, help="If True, use causal version of model.", ) + parser.add_argument( + "--use-hat", + type=str2bool, + default=False, + help="If True, use HAT loss.", + ) parser.add_argument( "--chunk-size", @@ -302,12 +335,12 @@ def get_parser(): files, e.g., checkpoints, log, etc, are saved """, ) - + parser.add_argument( - "--bpe-model", + "--bpe-tgt-model", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_en_1000/bpe.model", + help="Path to target data BPE model", ) parser.add_argument( @@ -317,7 +350,7 @@ def get_parser(): parser.add_argument( "--lr-batches", type=float, - default=7500, + default=5000, help="""Number of steps that affects how rapidly the learning rate decreases. We suggest not to change this.""", ) @@ -325,7 +358,7 @@ def get_parser(): parser.add_argument( "--lr-epochs", type=float, - default=3.5, + default=4, help="""Number of epochs that affects how rapidly the learning rate decreases. """, ) @@ -384,6 +417,12 @@ def get_parser(): default=42, help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--warm-step", + type=int, + default=20000, + help="warmup steps", + ) parser.add_argument( "--print-diagnostics", @@ -402,7 +441,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=4000, + default=16000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. The checkpoint filename @@ -415,7 +454,7 @@ def get_parser(): parser.add_argument( "--keep-last-k", type=int, - default=30, + default=10, help="""Only keep this number of checkpoints on disk. For instance, if it is 3, there are only 3 checkpoints in the exp-dir with filenames `checkpoint-xxx.pt`. @@ -439,7 +478,7 @@ def get_parser(): parser.add_argument( "--use-fp16", type=str2bool, - default=False, + default=True, help="Whether to use half precision training.", ) @@ -502,11 +541,11 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 + "valid_interval": 1000, # For the 100h subset, use 800 # parameters for zipformer "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, + "warm_step": 10000, "env_info": get_env_info(), } ) @@ -593,6 +632,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, + use_hat=params.use_hat, ) return model @@ -716,7 +756,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -751,18 +791,38 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] + tgt_texts = batch["supervisions"]["tgt_text"]['eng'] y = sp.encode(texts, out_type=int) + y_tgt = sp_tgt.encode(tgt_texts, out_type=int) y = k2.RaggedTensor(y).to(device) + y_tgt = k2.RaggedTensor(y_tgt).to(device) with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( x=feature, x_lens=feature_lens, - y=y, + y=y_tgt, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + inf_flag = False + if not torch.all(is_finite): + inf_flag = True + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start @@ -792,13 +852,13 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - return loss, info + return loss, info, inf_flag def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -808,10 +868,10 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( + loss, loss_info, _ = compute_loss( params=params, model=model, - sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=False, ) @@ -834,7 +894,7 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -902,10 +962,10 @@ def train_one_epoch( try: with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( + loss, loss_info, inf_flag = compute_loss( params=params, model=model, - sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=True, ) @@ -914,12 +974,15 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) + if not inf_flag: + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue except: # noqa save_bad_model() display_and_save_batch(batch, params=params, sp=sp) @@ -1012,6 +1075,7 @@ def train_one_epoch( params=params, model=model, sp=sp, + sp_tgt=sp_tgt, valid_dl=valid_dl, world_size=world_size, ) @@ -1064,8 +1128,8 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + sp_tgt = spm.SentencePieceProcessor() + sp_tgt.load(params.bpe_tgt_model) # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") @@ -1124,12 +1188,8 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + iwslt_ta = IWSLTDialectSTDataModule(args) + train_cuts = iwslt_ta.train_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1140,35 +1200,41 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) + if c.duration < 0.1 or c.duration > 30.0: + #logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + #) + return False + if c.supervisions == []: return False - # In pruned RNN-T, we require that T >= S # where T is the number of feature frames after subsampling # and S is the number of tokens in the utterance - # In ./zipformer.py, the conv module uses the following expression + # In ./conformer.py, the conv module uses the following expression # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp_tgt.encode(c.supervisions[0].custom['translated_text']['eng'], out_type=str) if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) + # logging.warning( + # f"Exclude cut with ID {c.id} from training. " + # f"Number of frames (before subsampling): {c.num_frames}. " + # f"Number of frames (after subsampling): {T}. " + # f"Text: {c.supervisions[0].text}. " + # f"Tokens: {tokens}. " + # f"Number of tokens: {len(tokens)}" + # ) return False return True + def remove_short_and_long_text(c: Cut): + # Keep only text with charachters between 20 and 400 + + return 3 <= len(c.supervisions[0].custom['translated_text']['eng']) <= 400 train_cuts = train_cuts.filter(remove_short_and_long_utt) + # train_cuts = train_cuts.filter(remove_short_and_long_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1177,20 +1243,19 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = librispeech.train_dataloaders( + train_dl = iwslt_ta.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_cuts = iwslt_ta.dev_cuts() + valid_dl = iwslt_ta.test_dataloaders(valid_cuts) if not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, optimizer=optimizer, - sp=sp, + sp_tgt=sp_tgt, params=params, ) @@ -1215,7 +1280,7 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - sp=sp, + sp_tgt=sp_tgt, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1250,6 +1315,7 @@ def display_and_save_batch( batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1274,7 +1340,8 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) + y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) + num_tokens = sum(len(i) for i in y_tgt) logging.info(f"num tokens: {num_tokens}") @@ -1282,7 +1349,7 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, + sp_tgt: spm.SentencePieceProcessor, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1295,10 +1362,10 @@ def scan_pessimistic_batches_for_oom( batch = train_dl.dataset[cuts] try: with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( + loss, _, _ = compute_loss( params=params, model=model, - sp=sp, + sp_tgt=sp_tgt, batch=batch, is_training=True, ) @@ -1313,7 +1380,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -1322,10 +1389,9 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + IWSLTDialectSTDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) - world_size = args.world_size assert world_size >= 1 if world_size > 1: diff --git a/egs/iwslt22_ta/ST/zipformer/train_asr.py b/egs/iwslt22_ta/ST/zipformer/train_asr.py deleted file mode 100755 index 5acae85c6..000000000 --- a/egs/iwslt22_ta/ST/zipformer/train_asr.py +++ /dev/null @@ -1,1417 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Johns Hopkins University (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# For non-streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --full-libri 1 \ - --max-duration 1000 - -# For streaming model training: -./zipformer/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp \ - --causal 1 \ - --full-libri 1 \ - --max-duration 1000 - -# small model 28.5M -./zipformer/train_asr.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-asr-small \ - --causal 0 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 256,512,768,1024,768,512 \ - --encoder-dim 128,256,256,512,256,256 \ - --encoder-unmasked-dim 64,128,128,256,128,128 \ - --base-lr 0.01 \ - --max-duration 1000 - -# medium model 42.5M -./zipformer/train_asr.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-asr-small2 \ - --causal 0 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,1024,1536,1024,768 \ - --encoder-dim 192,256,384,512,384,256 \ - --encoder-unmasked-dim 192,192,256,256,256,192 \ - --max-duration 800 - - -# large model 148.8M -./zipformer/train_asr.py \ - --world-size 4 \ - --num-epochs 40 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-asr-large \ - --causal 0 \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --max-duration 300 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import IWSLTDialectSTDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=256, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=256, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) - parser.add_argument( - "--bpe-tgt-model", - type=str, - default="data/lang_bpe_en_1000/bpe.model", - help="Path to target data BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=4, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=16000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=3, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 20000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"] - y = sp.encode(texts, out_type=int) - y_tgt = sp_tgt.encode(tgt_texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - y_tgt = k2.RaggedTensor(y_tgt).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - inf_flag = False - if not torch.all(is_finite): - inf_flag = True - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp_tgt = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - sp_tgt.load(params.bpe_tgt_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2**22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - iwslt_ta = IWSLTDialectSTDataModule(args) - train_cuts = iwslt_ta.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - #logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) - return False - if c.supervisions == []: - return False - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - return False - - return True - - def remove_short_and_long_text(c: Cut): - # Keep only text with charachters between 20 and 400 - - return 3 <= len(c.supervisions[0].text) <= 400 - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.filter(remove_short_and_long_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = iwslt_ta.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = iwslt_ta.dev_cuts() - valid_dl = iwslt_ta.test_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - sp_tgt=sp_tgt, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - sp_tgt=sp_tgt, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - IWSLTDialectSTDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/zipformer/train_st.py b/egs/iwslt22_ta/ST/zipformer/train_st.py deleted file mode 100755 index d6d60d9f4..000000000 --- a/egs/iwslt22_ta/ST/zipformer/train_st.py +++ /dev/null @@ -1,1422 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Johns Hopkins University (authors: Amir Hussein) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -# # For non-streaming model training: -# ./zipformer/train.py \ -# --world-size 4 \ -# --num-epochs 30 \ -# --start-epoch 1 \ -# --use-fp16 1 \ -# --exp-dir zipformer/exp \ -# --full-libri 1 \ -# --max-duration 1000 - -# # For streaming model training: -# ./zipformer/train.py \ -# --world-size 4 \ -# --num-epochs 30 \ -# --start-epoch 1 \ -# --use-fp16 1 \ -# --exp-dir zipformer/exp \ -# --causal 1 \ -# --full-libri 1 \ -# --max-duration 1000 - -# small model 28.5M -./zipformer/train_st.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-st-small \ - --causal 0 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 256,512,768,1024,768,512 \ - --encoder-dim 128,256,256,512,256,256 \ - --encoder-unmasked-dim 64,128,128,256,128,128 \ - --base-lr 0.01 \ - --max-duration 500 - -# medium model 42.5M -./zipformer/train_st.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-st-small2 \ - --causal 0 \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,1024,1536,1024,768 \ - --encoder-dim 192,256,384,512,384,256 \ - --encoder-unmasked-dim 192,192,256,256,256,192 \ - --max-duration 800 - -# large model 148.8M -./zipformer/train_asr.py \ - --world-size 4 \ - --num-epochs 40 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir zipformer/exp-large \ - --causal 0 \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - --max-duration 300 -""" - - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import IWSLTDialectSTDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=256, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=256, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_ta_1000/bpe.model", - help="Path to source data BPE model", - ) - parser.add_argument( - "--bpe-tgt-model", - type=str, - default="data/lang_bpe_en_1000/bpe.model", - help="Path to target data BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=4, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=16000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 20000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(","))), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - tgt_texts = batch["supervisions"]["tgt_text"] - y = sp.encode(texts, out_type=int) - y_tgt = sp_tgt.encode(tgt_texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - y_tgt = k2.RaggedTensor(y_tgt).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y_tgt, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - inf_flag = False - if not torch.all(is_finite): - inf_flag = True - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info, inf_flag - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info, _ = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint_impl( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info, inf_flag = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - if not inf_flag: - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - else: - continue - except: # noqa - save_bad_model() - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp_tgt = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - sp_tgt.load(params.bpe_tgt_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2**22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - iwslt_ta = IWSLTDialectSTDataModule(args) - train_cuts = iwslt_ta.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - #logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - #) - return False - if c.supervisions == []: - return False - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].custom['tgt_text'], out_type=str) - - if T < len(tokens): - # logging.warning( - # f"Exclude cut with ID {c.id} from training. " - # f"Number of frames (before subsampling): {c.num_frames}. " - # f"Number of frames (after subsampling): {T}. " - # f"Text: {c.supervisions[0].text}. " - # f"Tokens: {tokens}. " - # f"Number of tokens: {len(tokens)}" - # ) - return False - - return True - - def remove_short_and_long_text(c: Cut): - # Keep only text with charachters between 20 and 400 - - return 3 <= len(c.supervisions[0].custom['tgt_text']) <= 400 - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.filter(remove_short_and_long_text) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = iwslt_ta.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = iwslt_ta.dev_cuts() - valid_dl = iwslt_ta.test_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - sp_tgt=sp_tgt, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - sp_tgt=sp_tgt, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - y_tgt = sp_tgt.encode(supervisions["tgt_text"], out_type=int) - num_tokens = sum(len(i) for i in y_tgt) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - sp_tgt: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( - params=params, - model=model, - sp=sp, - sp_tgt=sp_tgt, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp, sp_tgt=sp_tgt) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - IWSLTDialectSTDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/iwslt22_ta/ST/zipformer/zipformer.py b/egs/iwslt22_ta/ST/zipformer/zipformer.py deleted file mode 100644 index 8d90198fd..000000000 --- a/egs/iwslt22_ta/ST/zipformer/zipformer.py +++ /dev/null @@ -1,2237 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import warnings -from typing import List, Optional, Tuple, Union -import logging -import torch -import random -from encoder_interface import EncoderInterface -from scaling import ( - Balancer, - BiasNorm, - Dropout2, - ChunkCausalDepthwiseConv1d, - ActivationDropoutAndLinear, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. - penalize_abs_values_gt, - softmax, - ScheduledFloat, - FloatLike, - limit_param_value, - convert_num_channels, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), - (20000.0, 0.1)) - - def _to_tuple(x): - """ Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - for u,d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) - - def get_feature_masks( - self, - x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [ 1.0 ] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0 - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and(mask1, - (torch.rand(1, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype)) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones(1, batch_size, channels, - dtype=x.dtype, device=x.device) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - outputs = [] - if torch.jit.is_scripting(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module(x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), - attn_mask=attn_mask) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, - chunk_size: int, - left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all (chunk_size * left_context_chunks >= - (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders)) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, - src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [ outputs[-1] ] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(device) - cached_nonlin_attn = torch.zeros(1, batch_size, downsample_left, nonlin_attn_head_dim).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(device) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(device) - states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), - (20000.0, ratio * x), - default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), - ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), - bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate, - straight_through_rate=0) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, pos_dim=pos_dim, num_heads=num_heads, - query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) - - self.feed_forward1 = FeedforwardModule(embed_dim, - (feedforward_dim * 3) // 4, - dropout) - - self.feed_forward2 = FeedforwardModule(embed_dim, - feedforward_dim, - dropout) - - self.feed_forward3 = FeedforwardModule(embed_dim, - (feedforward_dim * 5) // 4, - dropout) - - self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) - - self.conv_module1 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - self.conv_module2 = ConvolutionModule(embed_dim, - cnn_module_kernel, - causal=causal) - - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.2, max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.balancer2 = Balancer( - embed_dim, channel_dim=-1, - min_positive=0.45, max_positive=0.55, - min_abs=0.1, max_abs=4.0, - ) - - def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: - if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting(): - pass - elif not self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) - selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), - ff2_skip_rate) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) - - if torch.jit.is_scripting(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size, - src_key_padding_mask=src_key_padding_mask), - conv_skip_rate) - - if torch.jit.is_scripting(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), - ff3_skip_rate) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na - - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) - - src = self.bypass(src_orig, src) - - return ( - src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, - length_factor=1.0) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1. / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - output = src - - if not torch.jit.is_scripting(): - output = output * feature_mask - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if not torch.jit.is_scripting(): - output = output * feature_mask - - return output - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2 - ) = states[i * 6: (i + 1) * 6] - ( - output, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2 - ) = mod.streaming_forward( - output, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - return output, new_states - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. - """ - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value(self.bypass_scale, - min=float(self.scale_min), - max=float(self.scale_max)) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, - src_orig: Tensor, - src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - def __init__(self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, - downsample, dropout) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds,::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] - - return self.out_combiner(src_orig, src), new_states - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - def __init__(self, - channels: int, - downsample: int, - dropout: FloatLike): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - def __init__(self, - num_channels: int, - upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - def __init__( - self, embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0 - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0 - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T-1), T, - device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = (self.embed_dim ** 0.5) - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - : - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), - (4000.0, 0.0)) - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=query_head_dim**-0.25) - - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be suffixient to fix the problem. - self.balance_keys = Balancer(key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(pos_dim, - num_heads * pos_head_dim, - bias=False, - initial_scale=0.05) - - # the following are for diagnosics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt(attn_scores, - limit=25.0, - penalty=1.0e-04, - name=self.name) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[...,0:query_dim] - k = x[...,query_dim:2*query_dim] - # p is the position-encoding query - p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, (cached_key.shape[0], left_context_len) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == (num_heads, batch_size, seq_len, k_len), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy( - self, - attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).mean(dim=(1,2)) - logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, - num_heads * value_head_dim, - bias=True) - - self.out_proj = ScaledLinear(num_heads * value_head_dim, - embed_dim, bias=True, - initial_scale=0.05) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, (cached_val.shape[0], left_context_len) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - - return x, cached_val - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model. - """ - def __init__(self, - embed_dim: int, - feedforward_dim: int, - dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer(feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, - activation='SwooshL', - dropout_p=dropout, - dropout_shared_dim=0, bias=True, - initial_scale=0.1) - - self.out_whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear(hidden_channels, channels, - bias=True, - initial_scale=0.05) - - self.whiten1 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.whiten2 = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=-1) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, left_context_len + seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, (cached_x.shape[2], left_context_len) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - def __init__( - self, channels: int, kernel_size: int, causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.balancer1 = Balancer( - bottleneck_dim, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=bottleneck_dim, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2) - - self.balancer2 = Balancer( - bottleneck_dim, channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten(num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, channels, activation='SwooshR', - dropout_p=0.0, initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if not torch.jit.is_scripting() and chunk_size >= 0: - # Not support exporting a model for simulated streaming decoding - assert self.causal, "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=-1) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - - c = Zipformer2( - encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,) - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) diff --git a/egs/iwslt22_ta/ST/zipformer/zipformer.py b/egs/iwslt22_ta/ST/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/iwslt22_ta/ST/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 42ae9b9f2..e7ea652ca 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2023 Xiaomi Corp. (authors: Amir Hussein +# Fangjun Kuang, # Wei Kang, # Mingshuang Luo, # Zengwei Yao,