From 60317120caef28a43ed904a26263c57536ca95ab Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 17 Nov 2022 20:19:32 +0800
Subject: [PATCH] Revert "Apply new Black style changes"

---
 .git-blame-ignore-revs | 2 -
 .github/workflows/style_check.yml | 11 +-
 .pre-commit-config.yaml | 28 +-
 docker/README.md | 24 +-
 .../Dockerfile | 14 +-
 .../Dockerfile | 17 +-
 .../images/k2-gt-v1.9-blueviolet.svg | 2 +-
 .../images/python-gt-v3.6-blue.svg | 2 +-
 .../images/torch-gt-v1.6.0-green.svg | 2 +-
 docs/source/recipes/aishell/index.rst | 1 +
 docs/source/recipes/timit/index.rst | 1 +
 docs/source/recipes/timit/tdnn_ligru_ctc.rst | 28 +-
 docs/source/recipes/timit/tdnn_lstm_ctc.rst | 24 +-
 .../local/compute_fbank_aidatatang_200zh.py | 8 +-
 .../ASR/local/prepare_char.py | 8 +-
 .../ASR/local/prepare_lang.py | 4 +-
 .../ASR/local/test_prepare_lang.py | 4 +-
 egs/aidatatang_200zh/ASR/local/text2token.py | 21 +-
 egs/aidatatang_200zh/ASR/prepare.sh | 3 +-
 .../asr_datamodule.py | 110 +-
 .../pruned_transducer_stateless2/decode.py | 50 +-
 .../pruned_transducer_stateless2/export.py | 20 +-
 .../pretrained.py | 41 +-
 .../ASR/pruned_transducer_stateless2/train.py | 50 +-
 egs/aishell/ASR/conformer_ctc/conformer.py | 70 +-
 egs/aishell/ASR/conformer_ctc/decode.py | 29 +-
 egs/aishell/ASR/conformer_ctc/export.py | 17 +-
 egs/aishell/ASR/conformer_ctc/pretrained.py | 39 +-
 egs/aishell/ASR/conformer_ctc/subsampling.py | 16 +-
 .../ASR/conformer_ctc/test_subsampling.py | 3 +-
 egs/aishell/ASR/conformer_ctc/train.py | 12 +-
 egs/aishell/ASR/conformer_ctc/transformer.py | 44 +-
 egs/aishell/ASR/conformer_mmi/conformer.py | 70 +-
 egs/aishell/ASR/conformer_mmi/decode.py | 33 +-
 egs/aishell/ASR/conformer_mmi/subsampling.py | 16 +-
 egs/aishell/ASR/conformer_mmi/train.py | 8 +-
 egs/aishell/ASR/conformer_mmi/transformer.py | 44 +-
 .../local/compute_fbank_aidatatang_200zh.py | 8 +-
 .../ASR/local/compute_fbank_aishell.py | 8 +-
 egs/aishell/ASR/local/prepare_char.py | 8 +-
 egs/aishell/ASR/local/prepare_lang.py | 4 +-
 egs/aishell/ASR/local/test_prepare_lang.py | 4 +-
 .../pruned_transducer_stateless2/decode.py | 50 +-
 .../pruned_transducer_stateless2/export.py | 31 +-
 .../pretrained.py | 50 +-
 .../ASR/pruned_transducer_stateless2/train.py | 64 +-
 .../pruned_transducer_stateless3/decode.py | 73 +-
 .../pruned_transducer_stateless3/export.py | 54 +-
 .../ASR/pruned_transducer_stateless3/model.py | 8 +-
 .../pretrained.py | 50 +-
 .../ASR/pruned_transducer_stateless3/train.py | 79 +-
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 118 +-
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 33 +-
 egs/aishell/ASR/tdnn_lstm_ctc/model.py | 5 +-
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 37 +-
 egs/aishell/ASR/tdnn_lstm_ctc/train.py | 7 +-
 .../ASR/transducer_stateless/beam_search.py | 22 +-
 .../ASR/transducer_stateless/conformer.py | 70 +-
 .../ASR/transducer_stateless/decode.py | 39 +-
 .../ASR/transducer_stateless/decoder.py | 4 +-
 .../ASR/transducer_stateless/export.py | 20 +-
 egs/aishell/ASR/transducer_stateless/model.py | 4 +-
 .../ASR/transducer_stateless/pretrained.py | 36 +-
 egs/aishell/ASR/transducer_stateless/train.py | 15 +-
 .../ASR/transducer_stateless/transformer.py | 4 +-
 .../asr_datamodule.py | 85 +-
 .../transducer_stateless_modified-2/decode.py | 46 +-
 .../transducer_stateless_modified-2/export.py | 20 +-
 .../pretrained.py | 50 +-
 .../transducer_stateless_modified-2/train.py | 22 +-
 .../transducer_stateless_modified/decode.py | 46 +-
 .../transducer_stateless_modified/export.py | 20 +-
 .../pretrained.py | 50 +-
 .../transducer_stateless_modified/train.py | 15
+- egs/aishell2/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_aishell2.py | 8 +- .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 114 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- .../ASR/local/compute_fbank_aishell4.py | 8 +- egs/aishell4/ASR/local/prepare_char.py | 8 +- egs/aishell4/ASR/local/prepare_lang.py | 4 +- egs/aishell4/ASR/local/test_prepare_lang.py | 4 +- egs/aishell4/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 69 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 45 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_alimeeting.py | 8 +- egs/alimeeting/ASR/local/prepare_char.py | 8 +- egs/alimeeting/ASR/local/prepare_lang.py | 4 +- egs/alimeeting/ASR/local/test_prepare_lang.py | 4 +- egs/alimeeting/ASR/local/text2segments.py | 2 +- egs/alimeeting/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 60 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/csj/ASR/.gitignore | 2 +- egs/csj/ASR/local/compute_fbank_csj.py | 38 +- egs/csj/ASR/local/compute_fbank_musan.py | 17 +- egs/csj/ASR/local/conf/disfluent.ini | 55 +- egs/csj/ASR/local/conf/fluent.ini | 55 +- egs/csj/ASR/local/conf/number.ini | 55 +- egs/csj/ASR/local/conf/symbol.ini | 55 +- .../ASR/local/display_manifest_statistics.py | 4 +- egs/csj/ASR/local/prepare_lang_char.py | 17 +- egs/csj/ASR/local/validate_manifest.py | 7 +- .../ASR/conformer_ctc/asr_datamodule.py | 117 +- egs/gigaspeech/ASR/conformer_ctc/conformer.py | 66 +- egs/gigaspeech/ASR/conformer_ctc/decode.py | 29 +- .../ASR/conformer_ctc/gigaspeech_scoring.py | 3 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/gigaspeech/ASR/conformer_ctc/train.py | 12 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/preprocess_gigaspeech.py | 10 +- .../asr_datamodule.py | 117 +- .../pruned_transducer_stateless2/decode.py | 42 +- .../pruned_transducer_stateless2/export.py | 24 +- .../ASR/pruned_transducer_stateless2/train.py | 48 +- egs/librispeech/ASR/conformer_ctc/ali.py | 25 +- .../ASR/conformer_ctc/conformer.py | 66 +- egs/librispeech/ASR/conformer_ctc/decode.py | 29 +- egs/librispeech/ASR/conformer_ctc/export.py | 17 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/pretrained.py | 33 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/librispeech/ASR/conformer_ctc/train.py | 22 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../ASR/conformer_ctc2/attention.py | 19 +- .../ASR/conformer_ctc2/conformer.py | 65 +- egs/librispeech/ASR/conformer_ctc2/decode.py | 56 +- egs/librispeech/ASR/conformer_ctc2/export.py | 49 +- egs/librispeech/ASR/conformer_ctc2/train.py | 39 +- .../ASR/conformer_ctc2/transformer.py | 46 +- .../ASR/conformer_mmi/conformer.py | 70 +- egs/librispeech/ASR/conformer_mmi/decode.py | 29 +- .../ASR/conformer_mmi/subsampling.py | 16 +- .../ASR/conformer_mmi/test_subsampling.py | 3 +- .../ASR/conformer_mmi/test_transformer.py | 9 +- .../ASR/conformer_mmi/train-with-attention.py | 27 +- egs/librispeech/ASR/conformer_mmi/train.py | 27 +- 
.../ASR/conformer_mmi/transformer.py | 28 +- .../decode.py | 69 +- .../emformer.py | 119 +- .../export.py | 47 +- .../stream.py | 8 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../decode.py | 69 +- .../emformer.py | 108 +- .../export.py | 47 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../ASR/local/add_alignment_librispeech.py | 12 +- egs/librispeech/ASR/local/compile_hlg.py | 6 +- egs/librispeech/ASR/local/compile_lg.py | 4 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/compute_fbank_librispeech.py | 8 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../convert_transcript_words_to_tokens.py | 16 +- egs/librispeech/ASR/local/download_lm.py | 4 +- egs/librispeech/ASR/local/filter_cuts.py | 10 +- .../ASR/local/generate_unique_lexicon.py | 4 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 4 +- .../ASR/local/prepare_lm_training_data.py | 11 +- .../ASR/local/preprocess_gigaspeech.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- .../ASR/local/validate_manifest.py | 7 +- .../ASR/lstm_transducer_stateless/decode.py | 818 ++++++++++++ .../ASR/lstm_transducer_stateless/export.py | 388 ++++++ .../jit_pretrained.py | 322 +++++ .../ASR/lstm_transducer_stateless/lstm.py | 871 +++++++++++++ .../ASR/lstm_transducer_stateless/model.py | 210 +++ .../lstm_transducer_stateless/pretrained.py | 352 +++++ .../ASR/lstm_transducer_stateless/stream.py | 148 +++ .../streaming_decode.py | 968 ++++++++++++++ .../ASR/lstm_transducer_stateless/train.py | 1157 +++++++++++++++++ .../ASR/lstm_transducer_stateless2/decode.py | 67 +- .../ASR/lstm_transducer_stateless2/export.py | 59 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless2/model.py | 8 +- .../lstm_transducer_stateless2/ncnn-decode.py | 15 +- .../lstm_transducer_stateless2/pretrained.py | 40 +- .../streaming-ncnn-decode.py | 27 +- .../streaming-onnx-decode.py | 45 +- .../ASR/lstm_transducer_stateless2/train.py | 68 +- .../ASR/lstm_transducer_stateless3/decode.py | 79 +- .../ASR/lstm_transducer_stateless3/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless3/lstm.py | 14 +- .../lstm_transducer_stateless3/pretrained.py | 40 +- .../streaming_decode.py | 74 +- .../ASR/lstm_transducer_stateless3/train.py | 66 +- .../ASR/pruned2_knowledge/asr_datamodule.py | 125 +- .../ASR/pruned2_knowledge/beam_search.py | 18 +- .../ASR/pruned2_knowledge/conformer.py | 90 +- .../ASR/pruned2_knowledge/decode.py | 44 +- .../ASR/pruned2_knowledge/decoder.py | 4 +- .../ASR/pruned2_knowledge/decoder2.py | 84 +- .../ASR/pruned2_knowledge/export.py | 20 +- .../ASR/pruned2_knowledge/joiner.py | 4 +- .../ASR/pruned2_knowledge/model.py | 8 +- .../ASR/pruned2_knowledge/optim.py | 35 +- .../ASR/pruned2_knowledge/sampling.py | 180 ++- .../ASR/pruned2_knowledge/scaling.py | 51 +- .../ASR/pruned2_knowledge/scaling_tmp.py | 355 ++--- .../ASR/pruned2_knowledge/train.py | 50 +- .../pruned_stateless_emformer_rnnt2/decode.py | 69 +- .../emformer.py | 8 +- .../pruned_stateless_emformer_rnnt2/export.py | 47 +- .../pruned_stateless_emformer_rnnt2/model.py | 4 +- .../pruned_stateless_emformer_rnnt2/train.py | 44 +- .../beam_search.py | 26 +- .../ASR/pruned_transducer_stateless/decode.py | 44 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless/decoder.py | 4 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../ASR/pruned_transducer_stateless/model.py | 4 +- .../pruned_transducer_stateless/pretrained.py | 36 +- .../streaming_beam_search.py | 8 +- 
.../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless/train.py | 46 +- .../beam_search.py | 51 +- .../pruned_transducer_stateless2/conformer.py | 97 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/decoder.py | 8 +- .../pruned_transducer_stateless2/export.py | 24 +- .../pruned_transducer_stateless2/joiner.py | 4 +- .../ASR/pruned_transducer_stateless2/model.py | 8 +- .../ASR/pruned_transducer_stateless2/optim.py | 35 +- .../pretrained.py | 36 +- .../pruned_transducer_stateless2/scaling.py | 56 +- .../streaming_beam_search.py | 12 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless2/train.py | 58 +- .../asr_datamodule.py | 85 +- .../decode-giga.py | 54 +- .../pruned_transducer_stateless3/decode.py | 74 +- .../pruned_transducer_stateless3/export.py | 32 +- .../gigaspeech.py | 8 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 36 +- .../scaling_converter.py | 10 +- .../streaming_decode.py | 39 +- .../pruned_transducer_stateless3/test_onnx.py | 24 +- .../ASR/pruned_transducer_stateless3/train.py | 65 +- .../pruned_transducer_stateless4/decode.py | 79 +- .../pruned_transducer_stateless4/export.py | 47 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless4/train.py | 61 +- .../pruned_transducer_stateless5/conformer.py | 118 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 66 +- .../pruned_transducer_stateless6/conformer.py | 67 +- .../pruned_transducer_stateless6/decode.py | 69 +- .../pruned_transducer_stateless6/export.py | 24 +- .../extract_codebook_index.py | 3 +- .../hubert_decode.py | 17 +- .../hubert_xlarge.py | 22 +- .../ASR/pruned_transducer_stateless6/model.py | 12 +- .../ASR/pruned_transducer_stateless6/train.py | 65 +- .../pruned_transducer_stateless6/vq_utils.py | 31 +- .../pruned_transducer_stateless7/decode.py | 67 +- .../pruned_transducer_stateless7/decoder.py | 6 +- .../pruned_transducer_stateless7/export.py | 47 +- .../jit_pretrained.py | 21 +- .../pruned_transducer_stateless7/joiner.py | 4 +- .../ASR/pruned_transducer_stateless7/model.py | 16 +- .../ASR/pruned_transducer_stateless7/optim.py | 435 +++---- .../pretrained.py | 40 +- .../pruned_transducer_stateless7/scaling.py | 487 ++++--- .../scaling_converter.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 88 +- .../pruned_transducer_stateless7/zipformer.py | 654 +++++----- .../pruned_transducer_stateless8/decode.py | 67 +- .../pruned_transducer_stateless8/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless8/model.py | 4 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless8/train.py | 99 +- .../ASR/streaming_conformer_ctc/README.md | 16 +- .../ASR/streaming_conformer_ctc/conformer.py | 116 +- .../streaming_decode.py | 68 +- .../ASR/streaming_conformer_ctc/train.py | 16 +- .../streaming_conformer_ctc/transformer.py | 40 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 113 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/librispeech/ASR/tdnn_lstm_ctc/model.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 8 +- egs/librispeech/ASR/transducer/beam_search.py | 14 +- egs/librispeech/ASR/transducer/decode.py | 28 +- egs/librispeech/ASR/transducer/export.py | 17 +- 
egs/librispeech/ASR/transducer/pretrained.py | 33 +- egs/librispeech/ASR/transducer/rnn.py | 24 +- egs/librispeech/ASR/transducer/test_rnn.py | 16 +- egs/librispeech/ASR/transducer/train.py | 12 +- .../ASR/transducer_lstm/beam_search.py | 14 +- egs/librispeech/ASR/transducer_lstm/decode.py | 28 +- .../ASR/transducer_lstm/encoder.py | 4 +- egs/librispeech/ASR/transducer_lstm/train.py | 12 +- .../ASR/transducer_stateless/alignment.py | 4 +- .../ASR/transducer_stateless/beam_search.py | 28 +- .../ASR/transducer_stateless/compute_ali.py | 24 +- .../ASR/transducer_stateless/conformer.py | 107 +- .../ASR/transducer_stateless/decode.py | 42 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/joiner.py | 8 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../transducer_stateless/test_compute_ali.py | 11 +- .../transducer_stateless/test_conformer.py | 4 +- .../ASR/transducer_stateless/train.py | 23 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../ASR/transducer_stateless2/decode.py | 42 +- .../ASR/transducer_stateless2/export.py | 20 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../ASR/transducer_stateless2/train.py | 23 +- .../decode.py | 42 +- .../export.py | 20 +- .../pretrained.py | 36 +- .../test_asr_datamodule.py | 4 +- .../train.py | 22 +- egs/ptb/LM/local/sort_lm_training_data.py | 4 +- .../LM/local/test_prepare_lm_training_data.py | 4 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../ASR/local/compute_fbank_spgispeech.py | 14 +- egs/spgispeech/ASR/local/prepare_splits.py | 8 +- .../asr_datamodule.py | 100 +- .../pruned_transducer_stateless2/decode.py | 66 +- .../pruned_transducer_stateless2/export.py | 26 +- .../ASR/pruned_transducer_stateless2/train.py | 51 +- .../ASR/local/compute_fbank_tal_csasr.py | 8 +- egs/tal_csasr/ASR/local/prepare_char.py | 4 +- egs/tal_csasr/ASR/local/prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/test_prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 77 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_tedlium.py | 8 +- .../convert_transcript_words_to_bpe_ids.py | 4 +- egs/tedlium3/ASR/local/prepare_lexicon.py | 11 +- egs/tedlium3/ASR/local/prepare_transcripts.py | 11 +- .../ASR/pruned_transducer_stateless/decode.py | 38 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../pruned_transducer_stateless/pretrained.py | 41 +- .../ASR/pruned_transducer_stateless/train.py | 35 +- .../transducer_stateless/asr_datamodule.py | 127 +- .../ASR/transducer_stateless/beam_search.py | 30 +- .../ASR/transducer_stateless/decode.py | 31 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless/train.py | 11 +- egs/timit/ASR/RESULTS.md | 2 +- egs/timit/ASR/local/compile_hlg.py | 4 +- egs/timit/ASR/local/compute_fbank_timit.py | 8 +- egs/timit/ASR/local/prepare_lexicon.py | 8 +- egs/timit/ASR/prepare.sh | 4 +- egs/timit/ASR/tdnn_ligru_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_ligru_ctc/model.py | 12 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_ligru_ctc/train.py | 4 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 104 +- egs/timit/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_lstm_ctc/model.py | 5 +- 
egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_lstm_ctc/train.py | 4 +- .../compute_fbank_wenetspeech_dev_test.py | 11 +- .../local/compute_fbank_wenetspeech_splits.py | 10 +- egs/wenetspeech/ASR/local/prepare_char.py | 8 +- .../ASR/local/preprocess_wenetspeech.py | 6 +- egs/wenetspeech/ASR/local/text2token.py | 21 +- egs/wenetspeech/ASR/prepare.sh | 2 +- .../asr_datamodule.py | 121 +- .../pruned_transducer_stateless2/decode.py | 64 +- .../pruned_transducer_stateless2/export.py | 28 +- .../jit_pretrained.py | 21 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- .../pruned_transducer_stateless5/conformer.py | 97 +- .../pruned_transducer_stateless5/decode.py | 75 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless5/export.py | 20 +- .../pretrained.py | 41 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- egs/yesno/ASR/local/compile_hlg.py | 4 +- egs/yesno/ASR/local/compute_fbank_yesno.py | 12 +- egs/yesno/ASR/tdnn/asr_datamodule.py | 74 +- egs/yesno/ASR/tdnn/decode.py | 29 +- egs/yesno/ASR/tdnn/pretrained.py | 37 +- egs/yesno/ASR/tdnn/train.py | 4 +- egs/yesno/ASR/transducer/decode.py | 25 +- egs/yesno/ASR/transducer/train.py | 4 +- icefall/char_graph_compiler.py | 8 +- icefall/checkpoint.py | 12 +- icefall/decode.py | 40 +- icefall/diagnostics.py | 80 +- icefall/dist.py | 4 +- icefall/env.py | 4 +- icefall/graph_compiler.py | 4 +- icefall/hooks.py | 19 +- icefall/lexicon.py | 16 +- icefall/mmi.py | 29 +- icefall/mmi_graph_compiler.py | 8 +- icefall/rnn_lm/compute_perplexity.py | 15 +- icefall/rnn_lm/dataset.py | 8 +- icefall/rnn_lm/export.py | 17 +- icefall/rnn_lm/model.py | 28 +- icefall/rnn_lm/train.py | 11 +- icefall/shared/make_kn_lm.py | 184 +-- icefall/utils.py | 66 +- pyproject.toml | 2 +- setup.py | 3 +- test/test_checkpoint.py | 6 +- test/test_decode.py | 1 - test/test_graph_compiler.py | 4 +- test/test_utils.py | 4 +- 441 files changed, 14535 insertions(+), 6789 deletions(-) delete mode 100644 .git-blame-ignore-revs mode change 100644 => 100755 egs/aishell2/ASR/local/__init__.py mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py mode change 100644 => 100755 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/decode.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/export.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py mode change 100644 => 100755 egs/librispeech/ASR/lstm_transducer_stateless/train.py diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs deleted file mode 100644 index c5908fc89..000000000 --- a/.git-blame-ignore-revs +++ /dev/null @@ -1,2 +0,0 @@ -# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890) -d110b04ad389134c82fa314e3aafc7b40043efb0 diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 45d261ccc..90459bc1c 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,18 +45,17 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip 
black==22.3.0 flake8==5.0.4 click==8.1.0 - # Click issue fixed in https://github.com/psf/black/pull/2966 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash working-directory: ${{github.workspace}} run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ - --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 + flake8 . --count --show-source --statistics + flake8 . - name: Run black shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cb213327..446ba0fe7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,38 +1,26 @@ repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 21.6b0 hooks: - id: black - args: ["--line-length=88"] - additional_dependencies: ['click==8.1.0'] + args: [--line-length=80] + additional_dependencies: ['click==8.0.1'] exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 3.9.2 hooks: - id: flake8 - args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] - - # What are we ignoring here? - # E203: whitespace before ':' - # E266: too many leading '#' for block comment - # E501: line too long - # F401: module imported but unused - # E402: module level import not at top of file - # F403: 'from module import *' used; unable to detect undefined names - # F841: local variable is assigned to but never used - # W503: line break before binary operator - # In addition, the default ignore list is: - # E121,E123,E126,E226,E24,E704,W503,W504 + args: [--max-line-length=80] - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.9.2 hooks: - id: isort - args: ["--profile=black"] + args: [--profile=black, --line-length=80] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.0.1 hooks: - id: check-executables-have-shebangs - id: end-of-file-fixer diff --git a/docker/README.md b/docker/README.md index c14b9bf75..6f2314e96 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. +If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0. 
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with ```bash $ nvidia-smi -Tue Sep 20 00:26:13 2022 +Tue Sep 20 00:26:13 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022 | 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022 ``` ## Building images locally -If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. -For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. ```dockerfile ENV http_proxy=http://aaa.bb.cc.net:8080 \ https_proxy=http://aaa.bb.cc.net:8080 ``` -Then, proceed with these commands. +Then, proceed with these commands. ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: @@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. -Overall, your docker run command should look like this. +Overall, your docker run command should look like this. ```bash docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 @@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re ### Linking to icefall in your host machine -If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. -Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. Use these commands once you are inside the container. 
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall docker exec -it icefall /bin/bash ``` -## Restarting a killed container that has been run before. +## Restarting a killed container that has been run before. ```bash docker start -ai icefall ``` @@ -111,4 +111,4 @@ docker start -ai icefall ## Sample usage of the CPU based images: ```bash docker run -it icefall /bin/bash -``` +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index ff9e40604..3637d2f11 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 # install normal source RUN apt-get update && \ @@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.12 && \ pip install graphviz - + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt -RUN pip install kaldifeat +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 5c7423fa5..17a8215f9 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,12 +1,12 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 RUN rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/nvidia-ml.list && \ apt-key del 7fa2af80 - + # install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ @@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18 curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ - rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* && \ mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv /opt/conda/lib/libcublas.so.11 
/opt/libcublas.so.11.bak && \ @@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.7.1 && \ pip install graphviz @@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install git+https://github.com/lhotse-speech/lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ @@ -88,3 +88,4 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall + diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg index 3019ff03d..534b2e534 100644 --- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg +++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg @@ -1 +1 @@ -k2: >= v1.9k2>= v1.9 +k2: >= v1.9k2>= v1.9 \ No newline at end of file diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg index df677ad09..4254dc58a 100644 --- a/docs/source/installation/images/python-gt-v3.6-blue.svg +++ b/docs/source/installation/images/python-gt-v3.6-blue.svg @@ -1 +1 @@ -python: >= 3.6python>= 3.6 +python: >= 3.6python>= 3.6 \ No newline at end of file diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg index d7007d742..d3ece9a17 100644 --- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg +++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg @@ -1 +1 @@ -torch: >= 1.6.0torch>= 1.6.0 +torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst index b77d59bca..d072d6e9c 100644 --- a/docs/source/recipes/aishell/index.rst +++ b/docs/source/recipes/aishell/index.rst @@ -19,3 +19,4 @@ It can be downloaded from ``_ tdnn_lstm_ctc conformer_ctc stateless_transducer + diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst index 5ee147be7..17f40cdb7 100644 --- a/docs/source/recipes/timit/index.rst +++ b/docs/source/recipes/timit/index.rst @@ -6,3 +6,4 @@ TIMIT tdnn_ligru_ctc tdnn_lstm_ctc + diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst index 3d7aefe02..186420ee7 100644 --- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst +++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst @@ -148,10 +148,10 @@ Some commonly used options are: $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17 - uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, - ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, - ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, 
``epoch-19.pt``, - ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, + uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, + ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, + ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, + ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_ligru_ctc/pretrained.py + ./tdnn_ligru_ctc/pretrained.py --method 1best - --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt - --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt - --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt + --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt + --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -337,7 +337,7 @@ The output is: 2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started 2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 20:41:39,829 INFO [pretrained.py:267] + 2021-11-08 20:41:39,829 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh @@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \ --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.1 \ - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -378,7 +378,7 @@ The decoding output is: 2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started 2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:37:56,348 INFO [pretrained.py:267] + 2021-11-08 20:37:56,348 INFO [pretrained.py:267] 
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst index ee67a6edc..6f760a9ce 100644 --- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst @@ -148,8 +148,8 @@ Some commonly used options are: $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10 - uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, - ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, + uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, + ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_lstm_ctc/pretrained.py + ./tdnn_lstm_ctc/pretrained.py --method 1best - --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt - --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt - --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt + --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt + --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -335,7 +335,7 @@ The output is: 2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started 2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 21:02:54,387 INFO [pretrained.py:267] + 2021-11-08 21:02:54,387 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh @@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.08 \ - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -376,7 +376,7 @@ The decoding output is: 2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', 
'./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:05:27,878 INFO [pretrained.py:267] + 2021-11-08 20:05:27,878 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index 387c14acf..fb2751c0f 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -114,7 +116,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py index 6b440dfb3..d9e47d17a 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py @@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index c8cf9b881..e5ae89ec4 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -317,7 +317,9 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") + parser.add_argument( + "--lang-dir", type=str, help="The lang dir, data/lang_phone" + ) return parser.parse_args() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 2be639b7a..71be2a613 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help=( - "number of characters to split, i.e., aabb -> a a b" - " b with -n 1 and aa bb with -n 2" - ), + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default="", type=str, help="space symbol") + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -66,7 +66,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -106,7 +108,8 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text + token_table[txt] if txt in token_table else oov_id + for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -132,7 +135,9 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 4749e1b7f..039951354 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -106,10 +106,11 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! 
-f $lang_char_dir/words.txt ]; then ./local/prepare_words.py \ --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt + --output-file $lang_char_dir/words.txt fi if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py fi fi + diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 8c94f5bea..6a5b57e24 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,12 +81,10 @@ class Aidatatang_200zhAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -98,91 +96,75 @@ class Aidatatang_200zhAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=300, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -196,22 +178,18 @@ class Aidatatang_200zhAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -227,20 +205,24 @@ class Aidatatang_200zhAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -255,7 +237,9 @@ class Aidatatang_200zhAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -298,7 +282,9 @@ class Aidatatang_200zhAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -354,7 +340,9 @@ class Aidatatang_200zhAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 3f582ef04..f0407f429 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -69,7 +69,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -88,30 +92,25 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--batch", type=int, default=None, - help=( - "It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0." - ), + help="It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -193,7 +192,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,7 +249,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -264,7 +266,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -310,7 +315,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -381,7 +390,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -414,7 +425,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 34f4d3ddf..00b54c39f 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -62,20 +62,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -106,7 +103,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) return parser @@ -175,7 +173,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index 3c96ed07b..eb5e6b0d4 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,11 +85,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -114,12 +112,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -166,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -196,9 +193,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -259,7 +257,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,7 +284,10 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +339,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index c7b1a4266..d46838b68 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -81,7 +81,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -185,45 +187,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -543,15 +542,22 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -705,7 +711,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -805,7 +813,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index f5b5873b4..cb7205e51 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an 
PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index a30fa52df..751b7d5b5 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -58,19 +58,16 @@ def get_parser(): "--epoch", type=int, default=49, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -404,7 +401,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -432,7 +431,9 @@ def save_results( # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -440,7 +441,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -559,7 +562,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 9ee405e8b..42b8c29e7 100644 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -40,20 +40,17 @@ def get_parser(): "--epoch", type=int, default=84, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=25, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -160,7 +157,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index e3d5a20e3..27776bc24 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -46,29 +46,27 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( "--tokens-file", type=str, - help="Path to tokens.txtUsed only when method is ctc-decoding", + help="Path to tokens.txt" "Used only when method is ctc-decoding", ) parser.add_argument( "--words-file", type=str, - help="Path to words.txtUsed when method is NOT ctc-decoding", + help="Path to words.txt" "Used when method is NOT ctc-decoding", ) parser.add_argument( "--HLG", type=str, - help="Path to HLG.pt.Used when method is NOT ctc-decoding", + help="Path to HLG.pt." "Used when method is NOT ctc-decoding", ) parser.add_argument( @@ -165,12 +163,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -214,9 +210,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,7 +274,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -372,7 +371,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py index 81fa234dd..e3361d0c9 100755 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py @@ -16,8 +16,9 @@ # limitations under the License. 
+from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling import torch -from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index c2cbe6e3b..a228cc1fe 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -382,7 +382,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -518,7 +520,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -626,7 +630,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index a3e50e385..f93914aaa 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -149,7 +149,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,7 +183,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -262,17 +266,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -333,17 +343,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, 
padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,7 +632,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -818,7 +836,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -839,7 +859,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index f5b5873b4..cb7205e51 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -157,7 +157,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,14 +177,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -216,7 +222,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -335,7 +343,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -351,7 +361,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -621,9 +633,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -691,25 +703,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." 
+ ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -746,7 +766,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -758,7 +780,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -792,9 +816,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -817,7 +845,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index a43183063..4db367e36 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -59,19 +59,16 @@ def get_parser(): "--epoch", type=int, default=49, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -416,7 +413,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -444,7 +443,9 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -452,7 +453,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -547,7 +550,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -576,7 +581,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py index 398837a46..720ed6c22 100644 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ b/egs/aishell/ASR/conformer_mmi/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 09cd6e60c..685831d09 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -511,7 +511,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -623,7 +625,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index a3e50e385..f93914aaa 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -149,7 +149,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,7 +183,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -262,17 +266,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -333,17 +343,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,7 +632,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -818,7 +836,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -839,7 +859,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 037971927..42700a972 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,7 +87,9 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -114,7 +116,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 115ca1031..deab6c809 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -83,7 +83,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -109,7 +111,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index 6b440dfb3..d9e47d17a 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] for i in 
range(len(pieces) - 1): w = word if i == 0 else eps @@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index c8cf9b881..e5ae89ec4 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -317,7 +317,9 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") + parser.add_argument( + "--lang-dir", type=str, help="The lang dir, data/lang_phone" + ) return parser.parse_args() diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ b/egs/aishell/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index ae926ec66..a12934d55 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -114,11 +118,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -186,7 +188,8 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -246,7 +249,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -258,7 +263,10 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -302,7 +310,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -375,7 +387,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -401,7 +415,9 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -412,7 +428,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -456,7 +473,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -485,7 +504,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index 5f6888db4..feababdd2 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -50,7 +50,11 @@ from pathlib import Path import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + 
average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -154,7 +157,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -187,7 +191,9 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + filename = ( + params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + ) model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -195,14 +201,17 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir + / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index f754a7b9e..3c38e5db7 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -87,11 +87,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -117,12 +115,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,16 +165,15 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help=( - "Maximum number of symbols per frame. " - "Use only when --method is greedy_search" - ), + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", ) add_model_arguments(parser) @@ -201,9 +196,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -260,9 +256,13 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) num_waves = encoder_out.size(0) hyp_list = [] @@ -310,7 +310,9 @@ def main(): beam=params.beam_size, ) else: - raise ValueError(f"Unsupported decoding method: {params.method}") + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) hyp_list.append(hyp) hyps = [] @@ -327,7 +329,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 66ca23035..97d892754 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -49,6 +49,7 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn + from asr_datamodule import AishellAsrDataModule from conformer import Conformer from decoder import Decoder @@ -74,7 +75,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -200,7 +203,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -223,45 +227,42 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -560,7 +561,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -588,16 +593,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -713,7 +725,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise if params.print_diagnostics and batch_idx == 5: @@ -877,7 +891,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1015,7 +1029,9 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 6c505940d..d159e420b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -121,24 +121,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -206,7 +202,8 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -266,7 +263,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -278,7 +277,10 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -322,7 +324,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -395,7 +401,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -421,7 +429,9 @@ def save_results( # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -432,7 +442,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -477,7 +488,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -505,12 +518,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -537,12 +551,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -571,7 +586,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index e5a5d7c77..566902a85 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -88,24 +88,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
" - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -136,7 +132,8 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -169,12 +166,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -197,12 +195,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -230,7 +229,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -253,7 +252,9 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + filename = ( + params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + ) model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -261,14 +262,17 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir + / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index a4dda0d6d..e150e8230 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_datatang = decoder_datatang self.joiner_datatang = joiner_datatang - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + 
self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_datatang is not None: @@ -177,7 +179,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 109879952..04a0a882a 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -87,11 +87,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -117,12 +115,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,16 +165,15 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help=( - "Maximum number of symbols per frame. " - "Use only when --method is greedy_search" - ), + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", ) add_model_arguments(parser) @@ -201,9 +196,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -261,9 +257,13 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) num_waves = encoder_out.size(0) hyp_list = [] @@ -311,7 +311,9 @@ def main(): beam=params.beam_size, ) else: - raise ValueError(f"Unsupported decoding method: {params.method}") + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) hyp_list.append(hyp) hyps = [] @@ -328,7 +330,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index b24f533ff..feaef5cf6 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -96,7 +96,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -222,7 +224,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -245,45 +248,42 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." 
- ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -635,7 +635,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -666,16 +670,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -813,7 +824,9 @@ def train_one_epoch( ) # summary stats if datatang_train_dl is not None: - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + tot_loss = ( + tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info if aishell: aishell_tot_loss = ( @@ -834,7 +847,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise if params.print_diagnostics and batch_idx == 5: @@ -877,7 +892,9 @@ def train_one_epoch( cur_lr = scheduler.get_last_lr()[0] if datatang_train_dl is not None: datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " + tot_loss_str = ( + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + ) else: tot_loss_str = "" datatang_str = "" @@ -1050,7 +1067,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1059,7 +1076,9 @@ def run(rank, world_size, args): train_cuts = filter_short_and_long_utterances(train_cuts) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -1074,7 +1093,9 @@ def run(rank, world_size, args): if params.datatang_prob > 0: datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) + train_datatang_cuts = filter_short_and_long_utterances( + train_datatang_cuts + ) train_datatang_cuts = train_datatang_cuts.repeat(times=None) datatang_train_dl 
= asr_datamodule.train_dataloaders( train_datatang_cuts, @@ -1228,7 +1249,9 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 12ae6e7d4..d24ba6bb7 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -64,12 +64,10 @@ class AishellAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -81,74 +79,59 @@ class AishellAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--drop-last", @@ -160,18 +143,17 @@ class AishellAsrDataModule: "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -185,40 +167,40 @@ class AishellAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -233,7 +215,9 @@ class AishellAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -276,7 +260,9 @@ class AishellAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -322,7 +308,9 @@ class AishellAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -378,9 +366,13 @@ class AishellAsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz" + ) @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "aishell_cuts_test.jsonl.gz" + ) diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 8ef247438..66b734fc4 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -49,19 +49,16 @@ def get_parser(): "--epoch", type=int, default=19, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=5, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--method", @@ -268,7 +265,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -290,7 +289,9 @@ def save_results( # We compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) test_set_wers[key] = wer @@ -334,7 +335,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -359,7 +362,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) model.to(device) model.eval() @@ -387,7 +392,9 @@ def main(): lexicon=lexicon, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py index 1731e1ebe..5e04c11b4 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py @@ -66,7 +66,10 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] + [ + nn.LSTM(input_size=500, hidden_size=500, num_layers=1) + for _ in range(5) + ] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 52f9410cf..9bd810809 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -41,11 +41,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -55,7 +53,9 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) parser.add_argument( "--method", @@ -71,12 +71,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -114,9 +112,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -174,7 +173,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) features = features.permute(0, 2, 1) # now features is [N, C, T] with torch.no_grad(): @@ -218,7 +219,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index e574cf89b..7619b0551 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -49,7 +49,12 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index de0a8d0f5..9ed9b2ad1 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -47,9 +47,9 @@ def greedy_search( device = model.device - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -81,9 +81,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -157,7 +157,9 @@ class HypothesisList(object): """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -244,9 +246,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index e26c6c385..64114253d 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -155,7 +155,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) 
self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -173,14 +175,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -214,7 +220,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -333,7 +341,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -349,7 +359,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -619,9 +631,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -689,25 +701,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -744,7 +764,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -756,7 +778,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -790,9 +814,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -815,7 +843,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 1f7bb14e1..780b0c4bb 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -52,19 +52,16 @@ def get_parser(): "--epoch", type=int, default=30, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -102,7 +99,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -229,7 +227,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] batch_size = encoder_out.size(0) @@ -248,7 +248,9 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": @@ -317,7 +319,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -342,7 +346,9 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -353,7 +359,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -423,7 +430,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index 70e9e6c96..c2c6552a9 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -86,7 +86,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index e35b26fe0..4c6519b96 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -69,20 +69,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -113,7 +110,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -245,7 +243,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 591bbe44f..994305fc1 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -103,7 +103,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 8effc9815..db89c4d67 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -73,11 +73,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -102,12 +100,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -121,7 +117,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -214,9 +211,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -275,7 +273,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -319,7 +319,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index 62ffff473..d54157709 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -126,7 +126,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -388,7 +389,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -501,7 +504,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -620,7 +625,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index b3ff153c1..e851dcc32 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -250,7 +250,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 76e209f06..838e53658 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -29,7 +29,10 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + PrecomputedFeatures, +) from torch.utils.data import 
DataLoader from icefall.utils import str2bool @@ -43,69 +46,59 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets).", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -119,22 +112,18 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -148,11 +137,9 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. 
Used only in dev/test CutSet" - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet", ) def train_dataloaders( @@ -175,7 +162,9 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -184,7 +173,9 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -261,7 +252,9 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index fd4cb8385..ea3f94fd8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -93,19 +93,16 @@ def get_parser(): "--epoch", type=int, default=30, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -173,7 +170,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -229,7 +227,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -241,7 +241,10 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -285,7 +288,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -358,7 +365,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -384,7 +393,9 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -395,7 +406,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -436,7 +448,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index 32481829c..3bd2ceb11 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -68,20 +68,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -112,7 +109,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -243,7 +241,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 55701a007..a95a4bc52 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -87,11 +87,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -117,12 +115,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,16 +165,15 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help=( - "Maximum number of symbols per frame. " - "Use only when --method is greedy_search" - ), + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", ) return parser @@ -199,9 +194,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -258,9 +254,13 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,7 +308,9 @@ def main(): beam=params.beam_size, ) else: - raise ValueError(f"Unsupported decoding method: {params.method}") + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) hyp_list.append(hyp) hyps = [] @@ -325,7 +327,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 8fb7d1e49..225d0d709 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -149,7 +149,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -167,7 +168,8 @@ def get_parser(): "--datatang-prob", type=float, default=0.2, - help="The probability to select a batch from the aidatatang_200zh dataset", + help="The probability to select a batch from the " + "aidatatang_200zh dataset", ) return parser @@ -447,7 +449,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -601,7 +605,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) aishell_tot_loss.write_summary( tb_writer, "train/aishell_tot_", params.batch_idx_train ) @@ -729,7 +735,9 @@ def run(rank, world_size, args): train_datatang_cuts = train_datatang_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -768,7 +776,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 1e41942da..65fcda873 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -94,19 +94,16 @@ def get_parser(): "--epoch", type=int, default=30, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -174,7 +171,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -233,7 +231,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -245,7 +245,10 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -289,7 +292,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -362,7 +369,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -388,7 +397,9 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -399,7 +410,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -440,7 +452,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index ca1d4bd4a..11335a834 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -68,20 +68,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -112,7 +109,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -243,7 +241,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 038090461..262e822c2 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -87,11 +87,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -117,12 +115,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,16 +165,15 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help=( - "Maximum number of symbols per frame. " - "Use only when --method is greedy_search" - ), + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", ) return parser @@ -199,9 +194,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -258,9 +254,13 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,7 +308,9 @@ def main(): beam=params.beam_size, ) else: - raise ValueError(f"Unsupported decoding method: {params.method}") + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) hyp_list.append(hyp) hyps = [] @@ -325,7 +327,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index 5f116f2bd..d3ffccafa 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -142,7 +142,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -413,7 +414,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -526,7 +529,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,7 +657,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py old mode 100644 new mode 100755 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index ec0c584ca..d8d3622bd 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -83,7 +83,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -109,7 +111,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py old mode 100644 new mode 100755 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py old mode 100644 new mode 100755 index e8966b554..b7a21f579 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -76,12 +76,10 @@ class AiShell2AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -93,74 +91,59 @@ class AiShell2AsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." 
- ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--drop-last", @@ -172,18 +155,17 @@ class AiShell2AsrDataModule: "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -197,22 +179,18 @@ class AiShell2AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", ) group.add_argument( @@ -238,16 +216,20 @@ class AiShell2AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -262,7 +244,9 @@ class AiShell2AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -306,7 +290,9 @@ class AiShell2AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -362,7 +348,9 @@ class AiShell2AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -418,7 +406,9 @@ class AiShell2AsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz" + ) @lru_cache() def test_cuts(self) -> CutSet: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 64b64d1b1..915737f4a 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -168,24 +168,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -273,7 +269,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -351,7 +348,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -410,7 +409,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -536,7 +538,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -569,7 +573,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -620,7 +625,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -654,12 +661,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -682,12 +690,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -715,7 +724,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -740,7 +749,9 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index 547ce2069..bc7bd71cb 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -170,12 +167,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -198,12 +196,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -231,7 +230,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -267,7 +266,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 4b16511e8..09de1bece 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -81,11 +81,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -111,12 +109,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -163,7 +159,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -194,9 +191,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -256,11 +254,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -332,7 +334,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index d37e7bdca..838a0497f 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -92,7 +92,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -218,7 +220,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -241,45 +244,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -603,7 +603,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -632,16 +636,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -760,7 +771,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise if params.print_diagnostics and batch_idx == 5: @@ -816,7 +829,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -924,7 +939,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1089,7 +1104,9 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 400c406f0..3f50d9e3e 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -85,7 +85,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -118,7 +120,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py index 6b440dfb3..d9e47d17a 100755 --- a/egs/aishell4/ASR/local/prepare_char.py +++ b/egs/aishell4/ASR/local/prepare_char.py @@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py index c8cf9b881..e5ae89ec4 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -317,7 +317,9 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") + parser.add_argument( + "--lang-dir", type=str, help="The lang dir, data/lang_phone" + ) return parser.parse_args() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ b/egs/aishell4/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py index 2be639b7a..71be2a613 100755 --- a/egs/aishell4/ASR/local/text2token.py +++ b/egs/aishell4/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help=( - "number of characters to split, i.e., aabb -> a a b" - " b with -n 1 and aa bb with -n 2" - ), + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default="", type=str, help="space symbol") + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -66,7 +66,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -106,7 +108,8 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text + token_table[txt] if txt in token_table else oov_id + for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -132,7 +135,9 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 84c7f0443..7aa53ddda 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -74,12 +74,10 @@ class Aishell4AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the 
preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( @@ -93,81 +91,66 @@ class Aishell4AsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=300, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( @@ -181,18 +164,17 @@ class Aishell4AsrDataModule: "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -206,22 +188,18 @@ class Aishell4AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -244,20 +222,24 @@ class Aishell4AsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -272,7 +254,9 @@ class Aishell4AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -316,7 +300,9 @@ class Aishell4AsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -373,7 +359,9 @@ class Aishell4AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 616a88937..14e44c7d9 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -117,24 +117,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -205,7 +201,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -263,7 +260,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -278,7 +277,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,7 +326,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -395,7 +401,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -428,7 +436,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -471,7 +480,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -499,12 +510,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -531,12 +543,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -565,7 +578,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over 
epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index 3c580ff7b..993341131 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -140,7 +136,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -172,12 +169,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -204,12 +202,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -238,7 +237,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -277,7 +276,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py 
b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 8151442af..1fa893637 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -94,11 +94,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -124,12 +122,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -176,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -207,9 +204,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,11 +266,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -304,7 +306,10 @@ def main(): for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -345,7 +350,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index aacd23ecd..0a48b9059 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -85,7 +85,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, 
optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -211,7 +213,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -234,45 +237,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -599,7 +599,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -629,15 +633,22 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
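The hunk above only re-wraps this expression, but the schedule it encodes is worth spelling out: the pruned RNN-T loss is blended in gradually as training warms up, so the simple loss dominates at the start. A rough sketch with assumed variable names (an illustration, not the recipe's exact code):

def combine_rnnt_losses(simple_loss, pruned_loss, simple_loss_scale=0.5, warmup=1.0):
    # Before warmup reaches 1.0 the pruned loss is ignored; between 1.0 and
    # 2.0 it is down-weighted to 0.1; afterwards it contributes fully.
    if warmup < 1.0:
        pruned_loss_scale = 0.0
    elif warmup < 2.0:
        pruned_loss_scale = 0.1
    else:
        pruned_loss_scale = 1.0
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss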
info["loss"] = loss.detach().cpu().item() @@ -816,7 +827,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -924,7 +937,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index 96115a230..af926aa53 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -84,7 +84,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -119,7 +121,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py index 6b440dfb3..d9e47d17a 100755 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ b/egs/alimeeting/ASR/local/prepare_char.py @@ -86,7 +86,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[""] for i in pieces + ] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -140,7 +142,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index c8cf9b881..e5ae89ec4 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -317,7 +317,9 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") + parser.add_argument( + "--lang-dir", type=str, help="The lang dir, data/lang_phone" + ) return parser.parse_args() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ b/egs/alimeeting/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py index 27b904fc8..7c1019aa8 100644 --- a/egs/alimeeting/ASR/local/text2segments.py +++ b/egs/alimeeting/ASR/local/text2segments.py @@ -30,8 +30,8 @@ with word segmenting: import argparse -import jieba import paddle +import jieba from tqdm import tqdm paddle.enable_static() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py index 2be639b7a..71be2a613 100755 --- a/egs/alimeeting/ASR/local/text2token.py +++ b/egs/alimeeting/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help=( - "number of characters to split, i.e., aabb -> a a b" - " b with -n 1 and aa bb with -n 2" - ), + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default="", type=str, help="space symbol") + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -66,7 +66,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -106,7 +108,8 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text + token_table[txt] if txt in token_table else oov_id + for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -132,7 +135,9 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py 
index d0467a29e..bf6faad7a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,12 +81,10 @@ class AlimeetingAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -98,91 +96,75 @@ class AlimeetingAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=300, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." 
- ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -196,22 +178,18 @@ class AlimeetingAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -227,20 +205,24 @@ class AlimeetingAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -255,7 +237,9 @@ class AlimeetingAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -298,7 +282,9 @@ class AlimeetingAsrDataModule: # Drop feats to be on the safe side. 
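The next hunk only re-indents the OnTheFlyFeatures call, but for orientation this is roughly how these datamodules assemble a Lhotse dataset that computes 80-bin fbank features on the fly instead of reading precomputed ones (a sketch using the Lhotse classes these files already import):

from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset, OnTheFlyFeatures

train = K2SpeechRecognitionDataset(
    cut_transforms=[],    # e.g. CutMix with MUSAN cuts, CutConcatenate
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    input_transforms=[],  # e.g. SpecAugment
    return_cuts=True,
)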
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -355,7 +341,9 @@ class AlimeetingAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index ffaca1021..6358fe970 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -70,7 +70,11 @@ from beam_search import ( from lhotse.cut import Cut from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -89,30 +93,25 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--batch", type=int, default=None, - help=( - "It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0." - ), + help="It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -194,7 +193,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,7 +249,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -264,7 +266,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -310,7 +315,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -381,7 +390,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -414,7 +425,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -551,7 +563,8 @@ def main(): ) dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -561,7 +574,8 @@ def main(): ) test_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) ] cuts_test_webdataset = CutSet.from_webdataset( test_shards, @@ -574,7 +588,9 @@ def main(): return 1.0 <= c.duration cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) + cuts_test_webdataset = cuts_test_webdataset.filter( + remove_short_and_long_utt + ) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 482e52d83..8beec1b8a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -62,20 +62,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -106,7 +103,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -175,7 +173,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index afbf0960a..93b1e1f57 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,11 +85,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -114,12 +112,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -166,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -196,9 +193,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -259,7 +257,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,7 +284,10 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +339,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 158ea9c1b..81a0ede7f 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -81,7 +81,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -185,45 +187,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -543,15 +542,22 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -705,7 +711,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -805,7 +813,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index cd0e20c4c..5d965832e 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -5,4 +5,4 @@ notify_tg.py finetune_* misc.ini .vscode/* -offline/* +offline/* \ No newline at end of file diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 036ce925f..994dedbdd 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -25,10 +25,15 @@ from random import Random from typing import List, Tuple import torch -from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on +from lhotse import ( CutSet, Fbank, FbankConfig, + # fmt: off + # See the following for why LilcomChunkyWriter is preferred + # https://github.com/k2-fsa/icefall/pull/404 + # https://github.com/lhotse-speech/lhotse/pull/527 + # fmt: on LilcomChunkyWriter, RecordingSet, SupervisionSet, @@ -76,13 +81,17 @@ def make_cutset_blueprints( cut_sets.append((f"eval{i}", cut_set)) # Create train and valid cuts - logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") + logging.info( + "Loading, trimming, and shuffling the remaining core+noncore cuts." 
+ ) recording_set = RecordingSet.from_file( manifest_dir / "csj_recordings_core.jsonl.gz" ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") supervision_set = SupervisionSet.from_file( manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") + ) + SupervisionSet.from_file( + manifest_dir / "csj_supervisions_noncore.jsonl.gz" + ) cut_set = CutSet.from_manifests( recordings=recording_set, @@ -92,12 +101,15 @@ def make_cutset_blueprints( cut_set = cut_set.shuffle(Random(RNG_SEED)) logging.info( - f"Creating valid and train cuts from core and noncore,split at {split}." + "Creating valid and train cuts from core and noncore," + f"split at {split}." ) valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) train_set = CutSet.from_cuts(islice(cut_set, split, None)) - train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) + train_set = ( + train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) + ) cut_sets.extend([("valid", valid_set), ("train", train_set)]) @@ -110,9 +122,15 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") - parser.add_argument("--split", type=int, default=4000, help="Split at this index") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "--fbank-dir", type=Path, help="Path to save fbank features" + ) + parser.add_argument( + "--split", type=int, default=4000, help="Split at this index" + ) return parser.parse_args() @@ -123,7 +141,9 @@ def main(): extractor = Fbank(FbankConfig(num_mel_bins=80)) num_jobs = min(16, os.cpu_count()) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index f60e62c85..44a33c4eb 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -26,6 +26,7 @@ from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor + ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. It looks for manifests in the directory data/manifests. 
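For orientation, the CSJ fbank script above triples the training data by keeping the original cuts and adding 0.9x and 1.1x speed-perturbed copies before feature extraction. A condensed sketch of that pattern (the manifest path here is hypothetical):

from lhotse import CutSet

cuts = CutSet.from_file("data/manifests/cuts_train_raw.jsonl.gz")  # hypothetical path
# Original cuts plus two speed-perturbed copies, as in the recipe above.
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)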
@@ -83,7 +84,9 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -104,15 +107,21 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() if __name__ == "__main__": args = get_args() - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index c987e72c5..eb70673de 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? 
セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

<P> = 
+<P>
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d22f9eb8 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

<P> = 
+<P>
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..2613c3409 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

<P> = 
+<P>
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..8ba451dd5 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -
= +
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,3 +319,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..c9de21073 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,7 +37,9 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to cutset manifests" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index f0078421b..e4d996871 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,7 +68,8 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" ), ) @@ -86,7 +87,9 @@ def main(): args = get_args() logging.basicConfig( - format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + format=( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" + ), level=logging.INFO, ) @@ -108,7 +111,8 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" + f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -119,7 +123,8 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " + f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -138,7 +143,9 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) + (args.lang_dir / "userdef_string").write_text( + "\n".join(args.userdef_string) + ) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 89448a49c..0c4c6c1ea 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,7 +68,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -88,7 +89,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index c3e3e84bf..d78e26240 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,12 +61,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - 
"PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -78,91 +76,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -176,22 +158,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -205,25 +183,30 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -238,7 +221,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -271,7 +256,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -317,7 +304,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -373,7 +362,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ 
class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index b38ae9c8c..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,19 +62,16 @@ def get_parser(): "--epoch", type=int, default=0, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=1, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -479,7 +476,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -494,7 +493,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -527,7 +528,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -702,7 +705,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index 880aa76e2..ef53b77f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,7 +73,8 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index 3b94f0c4b..cdc85ce9a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,10 +78,13 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - 
nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 4883d04d8..2965cde18 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,7 +386,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -519,7 +521,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -637,7 +641,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 07beeb1f0..8209ee3ec 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,7 +77,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 0ee845ec8..6410249db 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,10 +47,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -136,7 +134,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff..48d10a157 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,13 +98,19 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9ae3f071e..c87686e1e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,12 +73,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -90,91 +88,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -188,22 +170,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -217,7 +195,8 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders( @@ -237,16 +216,20 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -261,7 +244,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -304,7 +289,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. 
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -360,7 +347,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -416,7 +405,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 9f5d4711b..5849a3471 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,7 +77,11 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -114,11 +118,9 @@ def get_parser(): "--avg", type=int, default=8, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -186,7 +188,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -255,7 +258,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -270,7 +275,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -316,7 +324,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -386,7 +398,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -420,7 +434,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -496,7 +511,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4d1a2356d..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -176,45 +178,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -554,16 +553,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -726,7 +732,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 0169d0f82..2828e309e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,19 +61,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -234,7 +231,9 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -259,7 +258,9 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return CutSet.from_cuts(cuts) @@ -288,7 +289,9 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + out_manifest_filename = ( + out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + ) for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != 
x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 66fdf82d9..3f3b1acda 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,19 +64,16 @@ def get_parser(): "--epoch", type=int, default=77, - help=( - "It specifies the 
checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=55, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -554,7 +551,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -569,7 +568,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -601,7 +602,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -806,7 +809,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index bdb8a85e5..28c28df01 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,20 +40,17 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -160,7 +157,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,10 +82,13 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 8cabf1a53..a2c0a5486 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,11 +48,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -191,12 +189,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -240,9 +236,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -303,7 +300,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -428,7 +427,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 1a1c2f4c5..6419f6816 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,7 +393,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -420,7 +422,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -449,7 +453,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) + .sum() + .item() ) return loss, info @@ -562,7 +568,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,7 +660,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -725,7 +733,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 356d3f21b..1375d7245 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,10 +18,11 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ +from scaling import ScaledLinear + class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -75,7 +76,9 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) self.num_heads = num_heads self.dropout = dropout @@ -91,7 +94,9 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) + self.in_proj_weight = ScaledLinear( + embed_dim, 3 * embed_dim, bias=bias + ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -102,8 +107,12 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_k = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index a6f1679ef..b906d2650 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,8 +29,9 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from subsampling import Conv2dSubsampling from torch import Tensor, nn +from subsampling import Conv2dSubsampling + from transformer import Supervisions, Transformer, encoder_padding_mask @@ -181,7 +182,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -353,7 +356,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -368,7 +373,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): 
self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -643,9 +650,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -714,25 +721,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -769,7 +784,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -777,9 +794,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -813,9 +834,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -838,7 +863,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) 
-> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 934177b1f..97f2f2d39 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,11 +90,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -132,13 +130,11 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -662,7 +658,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] for ref_text in texts: @@ -677,7 +675,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -709,7 +709,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -850,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -878,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -911,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -981,7 +985,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 0e1841d8d..584b3c3fc 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,7 +47,6 @@ import logging from pathlib import Path import torch -from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -56,8 +55,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon +from conformer import Conformer + from icefall.utils import str2bool +from icefall.lexicon import Lexicon def get_parser(): @@ -88,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. 
Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,12 +177,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -208,12 +206,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -241,7 +240,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 4d7137ad7..18fa3e69f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall import diagnostics from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -496,7 +498,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -525,7 +531,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -552,7 +560,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -570,7 +580,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -708,7 +720,8 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} failing batch names {batch_name}" + f"failing batch size:{batch_size} " + f"failing batch names {batch_name}" ) raise @@ -763,9 +776,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( - "inf" - ): + if loss_info["ctc_loss"] == float("inf") or loss_info[ + "att_loss" + ] == float("inf"): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -778,7 +791,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -870,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index d3443dc94..3ef7edc23 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,17 +21,19 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from attention import MultiheadAttention from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling +from attention import MultiheadAttention +from torch.nn.utils.rnn import pad_sequence + from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledEmbedding, ScaledLinear, + ScaledEmbedding, ) -from subsampling import Conv2dSubsampling -from torch.nn.utils.rnn import pad_sequence + # Note: TorchScript requires Dict/List/etc. to be fully typed. 
Supervisions = Dict[str, torch.Tensor] @@ -208,7 +210,9 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.encoder( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) return x, mask @@ -257,17 +261,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -328,17 +338,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -943,7 +959,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -964,7 +982,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 4d9ddaea9..97c8d83a2 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,7 +156,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -174,14 +176,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -215,7 +221,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -334,7 +342,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -350,7 +360,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -620,9 +632,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -690,25 +702,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." 
+ ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -745,7 +765,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -757,7 +779,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -791,9 +815,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -816,7 +844,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index e8390ded9..fc9861489 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,19 +60,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", ) parser.add_argument( @@ -481,7 +478,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -513,7 +512,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -652,7 +653,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -684,7 +687,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index ad9415987..5c3e1222e 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,9 +25,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -111,13 +115,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index d0bb017dd..937845d77 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling import torch -from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 25d18076d..08e680607 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import torch -from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, - add_eos, - add_sos, - decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, ) +from torch.nn.utils.rnn import pad_sequence + def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index f8c94cff9..011dadd73 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -361,7 +370,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -750,14 +762,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 5cfb2bfc7..9a5bdcce2 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ 
b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -368,7 +377,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -758,14 +770,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 2542d9abe..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,7 +148,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -180,7 +182,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -270,7 +274,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -335,7 +341,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -608,7 +616,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -877,7 +887,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -898,7 +910,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index a1c43f7f5..620d69a19 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from 
" f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 0639ba746..8ca7d5568 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." 
) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -533,7 +551,9 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query(torch.cat([right_context, utterance, summary])) + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -544,12 +564,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -564,7 +588,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection outputs = self.out_proj(attention) @@ -646,7 +672,12 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. 
""" - (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( utterance, right_context, summary, @@ -916,9 +947,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -957,10 +992,14 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) ( output_right_context_utterance, output_memory, @@ -975,7 +1014,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1110,7 +1151,11 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( + ( + src_att, + output_memory, + attn_cache, + ) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1250,7 +1295,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1269,7 +1316,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1430,7 +1479,9 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1592,8 +1643,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, 
self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1638,11 +1693,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1705,7 +1766,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 59105e286..4930881ea 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index c211b215e..9494e1fc1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,12 +68,14 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) self.hyp: Optional[List[int]] = None else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index abe83732a..61dbe8658 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += 
f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index a76417e5f..c07d8f76b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 9cb4a5afc..98b8290b5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No 
checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 09200f2e1..f16f5acc7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) 
x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -543,12 +561,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -563,7 +585,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -881,11 +905,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:-1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -922,12 +948,18 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) - (output_right_context_utterance, next_key, next_val,) = self.attention.infer( + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + ( + output_right_context_utterance, + next_key, + next_val, + ) = self.attention.infer( utterance=utterance, 
right_context=right_context, memory=pre_memory, @@ -935,7 +967,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, attn_cache def forward( @@ -1192,7 +1226,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1211,7 +1247,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1511,8 +1549,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1557,11 +1599,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1624,7 +1672,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 4d05b367c..ab15e0241 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 0486ac2eb..71150392d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." 
+ "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with 
open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2c2593b56..2bbc45d78 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cc34a72d8..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,7 +157,9 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." 
+ ) # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -168,14 +170,18 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info(f"Warning: {origin_id} does not have alignment.") + logging.info( + f"Warning: {origin_id} does not have alignment." + ) ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..c628dfd53 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,7 +159,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..45c4b7f5f 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,7 +132,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 97750f3ea..c0c7ef8c5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,7 +80,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 37fce11f4..5587106e5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,10 +48,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." 
+ "Determines batch size dynamically.", ) parser.add_argument( @@ -146,7 +144,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..ce7d087f0 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,7 +112,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -126,7 +128,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..056da29e5 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,7 +83,9 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -99,7 +101,9 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index f149b7871..133499c8b 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,19 +46,21 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help=( - "The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words." - ), + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="", help="The OOV word.") + parser.add_argument( + "--oov", type=str, default="", help="The OOV word." 
+ ) return parser.parse_args() -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,7 +87,9 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index fbcc9e24a..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,7 +79,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) removed += 1 return False @@ -124,7 +125,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " + f"{ratio:.3f}% data is removed." ) return ans @@ -153,7 +155,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 3459c2f5a..566c0743d 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,7 +91,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e121aefa9..dec8a7442 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,7 +150,9 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. 
- words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + words_pieces: List[List[str]] = [ + sp.id_to_piece(ids) for ids in words_pieces_ids + ] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 70343fef7..5070341f1 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,7 +137,8 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -153,14 +154,18 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 8aa5e461d..077f23039 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,7 +119,9 @@ def preprocess_giga_speech(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 807aaf891..7c57d629a 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,7 +64,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -84,7 +85,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100644 new mode 100755 index e69de29bb..27414d717 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./lstm_transducer_stateless/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + 
average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + feature = torch.nn.functional.pad( + feature, + (0, 0, 0, num_tail_padded_frames), + mode="constant", + value=LOG_EPS, + ) + feature_lens += num_tail_padded_frames + + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return 
{"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + 
else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
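+    # With return_cuts=True, each batch carries its lhotse cuts in
+    # batch["supervisions"]["cut"], which is where decode_dataset()
+    # reads the cut ids from.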
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100644 new mode 100755 index e69de29bb..13dac6009 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `lstm_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless/decode.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
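+        The exported file can later be loaded with torch.jit.load();
+        see ./jit_pretrained.py for an example.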
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" 
--iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100644 new mode 100755 index e69de29bb..594c33e4f --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. 
+You can use the following command to get the exported models: + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. 
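+      Each per-utterance result is a list of BPE token ids, e.g.
+      [[25, 113, 7], [64, 9]] for a batch of two utterances
+      (hypothetical ids).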
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + states = encoder.get_init_states(batch_size=features.size(0), device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + states=states, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index e69de29bb..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -0,0 +1,871 @@ +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, + ScaledLSTM, +) +from torch import nn + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[torch.Tensor, torch.Tensor] +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Unstack the lstm states corresponding to a batch of utterances into a list + of states, where the i-th entry is the state from the i-th utterance. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + """ + hidden_states, cell_states = states + + list_hidden_states = hidden_states.unbind(dim=1) + list_cell_states = cell_states.unbind(dim=1) + + ans = [ + (h.unsqueeze(1), c.unsqueeze(1)) + for (h, c) in zip(list_hidden_states, list_cell_states) + ] + return ans + + +def stack_states( + states_list: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Stack list of lstm states corresponding to separate utterances into a single + lstm state so that it can be used as an input for lstm when those utterances + are formed into a batch. + + Args: + state_list: + Each element in state_list corresponds to the lstm state for a single + utterance. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + + + Returns: + A new state corresponding to a batch of utterances. + It is a tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. 
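+    This is the inverse of :func:`unstack_states`, i.e.
+    ``stack_states(unstack_states(states))`` recovers the original
+    batched states.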
+ """ + hidden_states = torch.cat([s[0] for s in states_list], dim=1) + cell_states = torch.cat([s[1] for s in states_list], dim=1) + ans = (hidden_states, cell_states) + return ans + + +class RNN(EncoderInterface): + """ + Args: + num_features (int): + Number of input features. + subsampling_factor (int): + Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa + d_model (int): + Output dimension (default=512). + dim_feedforward (int): + Feedforward dimension (default=2048). + rnn_hidden_size (int): + Hidden dimension for lstm layers (default=1024). + num_encoder_layers (int): + Number of encoder layers (default=12). + dropout (float): + Dropout rate (default=0.1). + layer_dropout (float): + Dropout value for model-level warmup (default=0.075). + aux_layer_period (int): + Period of auxiliary layers used for random combiner during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 512, + dim_feedforward: int = 2048, + rnn_hidden_size: int = 1024, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + aux_layer_period: int = 0, + is_pnnx: bool = False, + ) -> None: + super(RNN, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + encoder_layer = RNNEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + rnn_hidden_size=rnn_hidden_size, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.encoder = RNNEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ) + if aux_layer_period > 0 + else None, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C), where N is the batch size, + T is the sequence length, C is the feature dimension. + x_lens: + A tensor of shape (N,), containing the number of frames in `x` + before padding. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A tuple of 3 tensors: + - embeddings: its shape is (N, T', d_model), where T' is the output + sequence lengths. 
+ - lengths: a tensor of shape (batch_size,) containing the number of + frames in `embeddings` before padding. + - updated states, whose shape is the same as the input states. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() + + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types to be the same as annotated # noqa + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) + + x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, batch_size: int = 1, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.rnn_hidden_size), + device=device, + ) + return (hidden_states, cell_states) + + +class RNNEncoderLayer(nn.Module): + """ + RNNEncoderLayer is made up of lstm and feedforward networks. + + Args: + d_model: + The number of expected features in the input (required). + dim_feedforward: + The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. + dropout: + The dropout value (default=0.1). + layer_dropout: + The dropout value for model-level warmup (default=0.075). + """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + rnn_hidden_size: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(RNNEncoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) + self.lstm = ScaledLSTM( + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, + ) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer. 
+ + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # lstm module + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) + src = self.dropout(src_lstm) + src + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, new_states + + +class RNNEncoder(nn.Module): + """ + RNNEncoder is a stack of N encoder layers. + + Args: + encoder_layer: + An instance of the RNNEncoderLayer() class (required). + num_layers: + The number of sub-encoder-layers in the encoder (required). + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: Optional[List[int]] = None, + ) -> None: + super(RNNEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size + + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None + if aux_layers is not None: + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer in turn. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. 
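+        Returns:
+          A tuple of 2 elements: the encoder output, with the same shape
+          as ``src``, and the updated states, with the same layout as the
+          input states (a pair of empty tensors when ``states`` is None).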
+ """ + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) + + output = src + + outputs = [] + + new_hidden_states = [] + new_cell_states = [] + + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + + return output, new_states + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-3)//2-1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >= 9, in_channels >= 9. + out_channels + Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 9 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=0, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+ + Returns: + Return a tensor of shape (N, ((T-3)//2-1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # Now x is of shape (N, ((T-3)//2-1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. 
+ """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. 
+ """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev # noqa + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + feature_dim = 80 + m = RNN( + num_features=feature_dim, + d_model=512, + rnn_hidden_size=1024, + dim_feedforward=2048, + num_encoder_layers=12, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = m( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e69de29bb..d71132b4a 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -0,0 +1,210 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). 
+ It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + reduction: str = "sum", + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert reduction in ("sum", "none"), reduction + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
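        # A worked example of this step (illustrative values): with
        # blank_id = 0 and y = k2.RaggedTensor([[5, 9], [3]]), add_sos()
        # above yields the ragged tensor [[0, 5, 9], [0, 3]], and the pad()
        # call below turns it into the dense tensor [[0, 5, 9], [0, 3, 0]]
        # of shape (B, S + 1), which is what the stateless decoder consumes.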
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction=reduction, + delay_penalty=delay_penalty, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100644 new mode 100755 index e69de29bb..2a6e2adc6 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
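The forward() above deliberately returns `simple_loss` and `pruned_loss` separately and leaves the weighting to the training loop. A minimal sketch of how such a combination typically looks; the scales and the warmup schedule here are illustrative assumptions, not values taken from this patch:

import torch

def combine_transducer_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    warmup: float,
    simple_loss_scale: float = 0.5,
) -> torch.Tensor:
    # Down-weight (or disable) the pruned loss until the model has warmed up,
    # so the simple loss can first learn a rough alignment.
    if warmup < 1.0:
        pruned_loss_scale = 0.0
    elif warmup < 2.0:
        pruned_loss_scale = 0.1
    else:
        pruned_loss_scale = 1.0
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

# e.g. loss = combine_transducer_losses(simple_loss, pruned_loss, warmup=1.5)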
+""" +Usage: + +(1) greedy search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. + +Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by +./lstm_transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index e69de29bb..97d890c82 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class Stream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + LOG_EPS: float = math.log(1e-10), + ) -> None: + """ + Args: + params: + It's the return value of :func:`get_params`. + cut_id: + The cut id of the current stream. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + device: + The device to run this stream. + LOG_EPS: + A float value used for padding. + """ + self.LOG_EPS = LOG_EPS + self.cut_id = cut_id + + # Containing attention caches and convolution caches + self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # It uses different attributes for different decoding methods. 
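        # Concretely: greedy_search keeps a plain token list (`self.hyp`)
        # primed with `context_size` blanks, because the stateless decoder
        # conditions on the last `context_size` tokens; modified_beam_search
        # keeps a HypothesisList of partial hypotheses; fast_beam_search
        # delegates its per-stream state to a k2.RnntDecodingStream built
        # from the decoding graph.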
+ self.context_size = params.context_size + self.decoding_method = params.decoding_method + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + self.hyp: Optional[List[int]] = None + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + self.ground_truth: str = "" + + self.feature: Optional[torch.Tensor] = None + # Make sure all feature frames can be used. + # We aim to obtain 1 frame after subsampling. + self.chunk_length = params.subsampling_factor + self.pad_length = 5 + self.num_frames = 0 + self.num_processed_frames = 0 + + # After all feature frames are processed, we set this flag to True + self._done = False + + def set_feature(self, feature: torch.Tensor) -> None: + assert feature.dim() == 2, feature.dim() + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + self.num_frames = feature.size(0) + num_tail_padded_frames + self.feature = torch.nn.functional.pad( + feature, + (0, 0, 0, self.pad_length + num_tail_padded_frames), + mode="constant", + value=self.LOG_EPS, + ) + + def get_feature_chunk(self) -> torch.Tensor: + """Get a chunk of feature frames. + + Returns: + A tensor of shape (ret_length, feature_dim). + """ + update_length = min( + self.num_frames - self.num_processed_frames, self.chunk_length + ) + ret_length = update_length + self.pad_length + + ret_feature = self.feature[ + self.num_processed_frames : self.num_processed_frames + ret_length + ] + # Cut off used frames. + # self.feature = self.feature[update_length:] + + self.num_processed_frames += update_length + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_feature + + @property + def id(self) -> str: + return self.cut_id + + @property + def done(self) -> bool: + """Return True if all feature frames are processed.""" + return self._done + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.decoding_method == "greedy_search": + return self.hyp[self.context_size :] + elif self.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] + else: + assert self.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100644 new mode 100755 index e69de29bb..d6376bdc0 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -0,0 +1,968 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method greedy_search \ + --use-averaged-model True + +(2) modified beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method modified_beam_search \ + --use-averaged-model True \ + --beam-size 4 + +(3) fast beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method fast_beam_search \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" +import argparse +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lstm import LOG_EPSILON, stack_states, unstack_states +from stream import Stream +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + T = encoder_out.size(1) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
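        # `hyps_shape` has axes [stream][hyp], so row_ids(1) gives, for every
        # active hypothesis, the index of the stream it belongs to. The
        # index_select below therefore repeats each stream's encoder frame
        # once per hypothesis, so that `current_encoder_out` and `decoder_out`
        # are aligned along dim 0 (both have num_hyps entries).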
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + streams: List[Stream], + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + streams: + A list of stream objects. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. 
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + +def decode_one_chunk( + model: nn.Module, + streams: List[Stream], + params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, +) -> List[int]: + """ + Args: + model: + The Transducer model. + streams: + A list of Stream objects. + params: + It is returned by :func:`get_params`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + A list of indexes indicating the finished streams. 
+ """ + device = next(model.parameters()).device + + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] + + for stream in streams: + # We should first get `stream.num_processed_frames` + # before calling `stream.get_feature_chunk()` + # since `stream.num_processed_frames` would be updated + num_processed_frames_list.append(stream.num_processed_frames) + feature = stream.get_feature_chunk() + feature_len = feature.size(0) + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling + tail_length = params.subsampling_factor + 5 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + + # Stack states of all streams + states = stack_states(state_list) + + encoder_out, encoder_out_lens, states = model.encoder( + x=features, + x_lens=feature_lens, + states=states, + ) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + processed_lens = ( + num_processed_frames // params.subsampling_factor + + encoder_out_lens + ) + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=processed_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + # Update cached states of each stream + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + + +def decode_dataset( + cuts: CutSet, + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. 
+ + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = next(model.parameters()).device + + log_interval = 300 + + fbank = create_streaming_feature_extractor() + + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + # Each utterance has a Stream. + stream = Stream( + params=params, + cut_id=cut.id, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) + + stream.states = model.encoder.get_init_states(device=device) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + feature = fbank(samples) + stream.set_feature(feature) + stream.ground_truth = cut.supervisions[0].text + + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + else: + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
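        # write_error_stats() also returns the overall WER for this
        # test set / decoding key as a float, which is collected in
        # `test_set_wers` and written to the summary file further below.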
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if 
params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.manual_seed(20220810) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100644 new mode 100755 index e69de29bb..d30fc260a --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -0,0 +1,1157 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
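Before the training script below, a small standalone sketch of the frame bookkeeping that Stream.set_feature() / get_feature_chunk() and decode_one_chunk() above rely on. It assumes params.subsampling_factor == 4 (the Conv2dSubsampling above reduces the frame rate to roughly 1/4); the constants mirror the ones in stream.py:

import math
import torch

chunk_length = 4               # params.subsampling_factor
pad_length = 5                 # right context needed by the subsampling conv
num_tail_padded_frames = 35    # tail padding added in Stream.set_feature()

feature = torch.randn(100, 80)                        # (num_frames, feature_dim)
num_frames = feature.size(0) + num_tail_padded_frames
feature = torch.nn.functional.pad(
    feature,
    (0, 0, 0, pad_length + num_tail_padded_frames),
    value=math.log(1e-10),
)

num_processed = 0
num_chunks = 0
while num_processed < num_frames:
    update = min(num_frames - num_processed, chunk_length)
    # Each chunk carries `pad_length` extra look-ahead frames, but the
    # stream only advances by `update`, so consecutive chunks overlap.
    chunk = feature[num_processed : num_processed + update + pad_length]
    num_processed += update
    num_chunks += 1

# 135 frames consumed 4 at a time -> 34 chunks, each aimed at producing
# one encoder frame after subsampling.
print(num_chunks)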
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 550 +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from lstm import RNN +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of RNN encoder layers..", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", + ) + + parser.add_argument( + "--rnn-hidden-size", + type=int, + default=1024, + help="Hidden dim for LSTM layers.", + ) + + parser.add_argument( + "--aux-layer-period", + type=int, + default=0, + help="""Peroid of auxiliary layers used for randomly combined during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=35, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. 
`model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "dim_feedforward": 2048, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = RNN( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + rnn_hidden_size=params.rnn_hidden_size, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + aux_layer_period=params.aux_layer_period, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + # # overwrite it + # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] + # print(scheduler.base_lrs) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. 
" + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index f7e1b5a54..bad4e243e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,24 +185,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -299,7 +295,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -477,7 +474,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -536,7 +535,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -698,7 +700,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -731,7 +735,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -784,7 +789,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -819,12 +826,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -852,12 +860,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -886,7 +895,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -952,7 +961,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 
0ad00cda3..190673638 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,24 +146,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -229,7 +225,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -345,7 +342,9 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) warmup = 1.0 torch.onnx.export( @@ -446,9 +445,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -547,12 +550,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -581,12 +585,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: 
raise ValueError( @@ -615,7 +620,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -689,7 +694,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 5a8efd718..da184b76f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,12 +86,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -126,9 +124,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -316,7 +315,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..fadeb4ac2 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..410de8d3d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,7 +156,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -198,9 +200,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -283,7 +286,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 7d931a286..bef0ad760 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,11 +92,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. 
" + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -121,12 +119,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -173,7 +169,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,9 +201,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -269,11 +267,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -345,7 +347,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..e47a05a9e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,7 +144,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -186,9 +188,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -226,7 +229,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -305,7 +310,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -321,7 +328,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -334,7 +343,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index b31fefa0a..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,12 +109,10 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -149,9 +147,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -200,7 +199,9 @@ class Model: sess_options=self.session_opts, ) - def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -257,7 +258,9 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) def run_joiner( self, @@ -300,7 +303,11 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, )[0] return torch.from_numpy(projected_encoder_out) @@ -319,7 +326,11 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, )[0] return torch.from_numpy(projected_decoder_out) @@ -358,7 +369,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -461,7 +474,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 08a895a75..5eaaf321f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -235,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -688,7 +692,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -701,9 +707,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -714,7 +725,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -945,7 +958,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -991,7 +1006,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1139,7 +1155,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a8d5605fb..9eee19379 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,24 +182,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -294,7 +290,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -389,7 +386,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -442,7 +441,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -520,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. 
Args: @@ -595,7 +599,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -604,7 +610,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -642,7 +650,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -669,7 +678,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -713,7 +724,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -745,12 +758,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -773,12 +787,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -806,7 +821,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -833,7 +848,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 51238f768..212c7bad6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,24 +122,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -176,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -284,12 +281,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -312,12 +310,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -345,7 +344,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -381,7 +380,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 180ba8c72..a3443cf0a 100755 --- 
a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,12 +85,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -125,9 +123,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -315,7 +314,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..90bc351f4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,7 +661,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -758,14 +760,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 4f8049245..0e48fef04 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). 
" - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 4e9063a40..cfa918ed5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,9 +101,8 @@ def get_parser(): "--epoch", type=int, default=40, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -120,24 +119,20 @@ def get_parser(): "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+ "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -204,7 +199,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -363,7 +359,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -380,7 +378,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,7 +539,9 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -581,7 +583,8 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor + encoder_out_lens + num_processed_frames // params.subsampling_factor + + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -593,7 +596,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -768,7 +773,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -810,7 +816,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -844,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < 
params.avg: raise ValueError( @@ -872,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -905,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index a1d19fb73..60a5a2be7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,7 +87,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -230,45 +232,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -607,7 +606,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
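For reference, the compute_loss hunks that follow only re-wrap the warmup gating described in the docstring above; the behaviour itself, pulled out as a minimal standalone sketch (not the training script), is:

    import torch

    def combine_rnnt_losses(
        simple_loss: torch.Tensor,
        pruned_loss: torch.Tensor,
        warmup: float,
        simple_loss_scale: float = 0.5,
    ) -> torch.Tensor:
        # Keep the pruned loss out of the objective until the model has
        # warmed up, then phase it in: 0.0 -> 0.1 -> 1.0.
        pruned_loss_scale = (
            0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)
        )
        return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss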
""" - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -647,7 +650,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -660,9 +665,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -673,7 +683,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -840,7 +852,10 @@ def train_one_epoch( rank=rank, ) - if batch_idx % params.log_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.log_interval == 0 + and not params.print_diagnostics + ): cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -857,7 +872,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if ( batch_idx > 0 @@ -992,7 +1009,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index fd2a5354a..8dd1459ca 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,18 +74,17 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -97,91 +96,75 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." 
- ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -195,22 +178,18 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -229,16 +208,20 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -253,7 +236,9 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -296,7 +281,9 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. 
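To make the MUSAN and concatenation options above concrete: the train_dataloaders hunks assemble a list of cut transforms before building the dataset. A minimal sketch, reusing only the lhotse calls that already appear in this patch (load_manifest, CutMix, CutConcatenate):

    import logging
    from pathlib import Path

    from lhotse import load_manifest
    from lhotse.dataset import CutConcatenate, CutMix

    def build_cut_transforms(args):
        transforms = []
        if args.enable_musan:
            cuts_musan = load_manifest(Path(args.manifest_dir) / "cuts_musan.json.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        if args.concatenate_cuts:
            logging.info(
                "Using cut concatenation with duration factor "
                f"{args.duration_factor} and gap {args.gap}."
            )
            # Concatenation goes first so that the gaps it introduces are
            # still covered by the noise mixed in by CutMix afterwards.
            transforms = [
                CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)
            ] + transforms
        return transforms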
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -353,7 +340,9 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -400,17 +389,23 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 785a8f097..2e9bf3e0b 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,7 +302,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -318,7 +320,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
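The get_most_probable hunk that follows only re-wraps a max(...) call; length-normalized selection simply divides each score by the hypothesis length. A self-contained sketch with a hypothetical Hyp record (the real class lives in beam_search.py):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Hyp:              # hypothetical stand-in for beam_search.Hypothesis
        ys: List[int]       # decoded token ids
        log_prob: float

    def get_most_probable(hyps: List[Hyp], length_norm: bool = False) -> Hyp:
        if length_norm:
            # Dividing by the length removes the bias towards short outputs.
            return max(hyps, key=lambda h: h.log_prob / len(h.ys))
        return max(hyps, key=lambda h: h.log_prob)

    best = get_most_probable([Hyp([0, 5, 9], -3.0), Hyp([0, 5], -2.5)], length_norm=True)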
""" if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -492,7 +496,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 3b6d0549d..295a35204 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface -from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.knowledge_base = create_knowledge_base( - knowledge_M, knowledge_N, knowledge_D - ) + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 + encoder_layer_fn = lambda: ConformerEncoderLayer( self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K, + knowledge_K ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,7 +187,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -207,14 +209,10 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup( - knowledge_M, - knowledge_N, - knowledge_D, - knowledge_K, - d_model, - knowledge_base, - ) + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) self.norm_final = BasicNorm(d_model) @@ -313,7 +311,9 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) + self.layers = nn.ModuleList( + [encoder_layer_fn() for i in range(num_layers)] + ) self.num_layers = num_layers def forward( @@ -367,7 +367,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -382,7 +384,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -657,9 +661,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -728,25 +732,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. 
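The attn_mask handling re-wrapped just above follows the standard PyTorch multi-head attention convention; for reference, the accepted shapes can be summarized in a small standalone helper (a sketch, not the attention module itself):

    import torch

    def normalize_attn_mask(attn_mask: torch.Tensor, bsz: int, num_heads: int,
                            tgt_len: int, src_len: int) -> torch.Tensor:
        """Return a mask broadcastable to (bsz * num_heads, tgt_len, src_len)."""
        if attn_mask.dim() == 2:
            if list(attn_mask.size()) != [tgt_len, src_len]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
            return attn_mask.unsqueeze(0)   # broadcast over batch * heads
        if attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, tgt_len, src_len]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
            return attn_mask
        raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")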
# convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -783,7 +795,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -791,9 +805,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -827,9 +845,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -852,7 +874,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 65da19f27..b4a9af55a 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -94,19 +98,16 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -185,7 +186,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -243,7 +245,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -258,7 +262,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -302,7 +309,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -374,7 +385,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -406,7 +419,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index 0b9c886c7..b6d94aaf1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index 2ca76a30c..db51fb1cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
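The decoder.py and decoder2.py hunks around here re-wrap the same left-padding call. What it does, sketched with made-up sizes (the real decoder uses its configured decoder_dim and a grouped Conv1d):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    context_size = 2                  # same default as --context-size above
    embedding_dim = 8                 # hypothetical size, for illustration only
    conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size,
                     groups=embedding_dim, bias=False)

    embedding_out = torch.randn(3, 10, embedding_dim)   # (N, U, embedding_dim)
    x = embedding_out.permute(0, 2, 1)                   # (N, embedding_dim, U)
    x = F.pad(x, pad=(context_size - 1, 0))              # pad on the left only
    y = conv(x)                                           # (N, embedding_dim, U)
    # Left padding keeps the output length at U, so each position sees only the
    # current token plus context_size - 1 previous ones (a stateless "n-gram").
    assert y.shape == (3, embedding_dim, 10)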
-from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F -from subsampling import ScaledConv1d from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -91,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -101,6 +102,7 @@ class Decoder(nn.Module): return embedding_out + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -169,13 +171,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -184,41 +181,34 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -227,38 +217,22 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, 
- 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = ( - "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," - " scale={scale}" - ) + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 1af05d9c8..96d1a30fb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) return parser @@ -176,7 +174,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 68c663b66..35f75ed2a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..599bf2506 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,7 +63,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -134,7 +136,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..432bf8220 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if 
p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -166,14 +176,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -281,9 +295,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 8cc930927..7b05e2f00 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,29 +3,32 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import random import timeit -from typing import Optional, Tuple - import torch +from torch import Tensor +from torch import nn +from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd +from typing import Tuple, Optional from scaling import ScaledLinear -from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +import random from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. + + + + def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M**N, D)) + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) nn.init.uniform_(ans, -a, a) return ans - def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -44,9 +47,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup( - weights: Tensor, indexes: Tensor, knowledge_base: Tensor -) -> Tensor: +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. 
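The Eden scheduler whose get_lr is reformatted above multiplies two smooth decay terms, one in batches and one in epochs. A standalone sketch of that factor (the knee values 5000 and 6 are typical settings, treated purely as assumptions here):

    def eden_factor(batch: int, epoch: int,
                    lr_batches: float = 5000.0, lr_epochs: float = 6.0) -> float:
        # Each term stays close to 1 while its argument is well below the knee
        # and then decays roughly like x ** -0.5.
        batch_part = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
        epoch_part = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
        return batch_part * epoch_part

    # The learning rate actually used is base_lr * eden_factor(batch, epoch).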
@@ -62,9 +65,9 @@ def weighted_matrix_lookup( # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -73,9 +76,7 @@ def weighted_matrix_lookup( class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward( - ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor - ) -> Tensor: + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -87,16 +88,15 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward( - weights.detach(), indexes.detach(), knowledge_base.detach() - ) + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) # (*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) #(*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad is False + assert weights.requires_grad == False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,19 +115,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. 
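The weighted_matrix_lookup restored above expresses a weighted combination of K knowledge-base rows as a batched matmul. A small self-contained check of that equivalence, using arbitrary illustrative shapes:

# Check that the matmul form equals an explicit weighted sum of K rows.
import torch

M, D, K = 10, 4, 3
knowledge_base = torch.randn(M, D)
indexes = torch.randint(0, M, (2, 5, K))   # (*, K) row indexes
weights = torch.rand(2, 5, K)              # (*, K) combination weights

lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten())
lookup = lookup.reshape(*indexes.shape, D)                    # (*, K, D)
ans = torch.matmul(weights.unsqueeze(-2), lookup).squeeze(-2)  # (*, D)

# The same combination written as an explicit sum over the K rows:
ref = (weights.unsqueeze(-1) * knowledge_base[indexes]).sum(dim=-2)
assert torch.allclose(ans, ref, atol=1e-6)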
lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul( - lookup, ans_grad.unsqueeze(-1) # (*, K, D) - ) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze( - -2 - ) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul(lookup, # (*, K, D) + ans_grad.unsqueeze(-1)) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -149,7 +146,6 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ - @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -158,23 +154,18 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - (logprobs,) = ctx.saved_tensors + logprobs, = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 + l = logprobs.reshape(-1, logprobs.shape[-1]) scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print( - "Negentropy[individual,combined] = ", - negentropy_individual.item(), - ", ", - negentropy.item(), - ) + print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -192,23 +183,18 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - - def __init__( - self, - M: int, - N: int, - D: int, - K: int, - embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001, - ): + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, + initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 4.0) self.M = M self.N = N self.K = K @@ -224,14 +210,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. 
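The hand-written backward above recomputes `lookup` rather than saving it, then applies the chain rule manually. A short sketch, with made-up shapes, verifying those two gradient formulas against plain autograd:

# Verify the manual gradients used in the custom backward above.
import torch

B, K, D = 4, 3, 5
weights = torch.rand(B, K, requires_grad=True)
lookup = torch.randn(B, K, D, requires_grad=True)

ans = torch.matmul(weights.unsqueeze(-2), lookup).squeeze(-2)  # (B, D)
ans_grad = torch.randn(B, D)
ans.backward(gradient=ans_grad)

# The closed-form gradients from WeightedMatrixLookupFunction.backward:
weights_grad = torch.matmul(lookup.detach(), ans_grad.unsqueeze(-1)).squeeze(-1)
lookup_grad = weights.detach().unsqueeze(-1) * ans_grad.unsqueeze(-2)

assert torch.allclose(weights.grad, weights_grad, atol=1e-6)
assert torch.allclose(lookup.grad, lookup_grad, atol=1e-6)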
""" - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -251,44 +237,38 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device, dtype=dtype), - torch.randn(B, T, E, device=device, dtype=dtype), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) + start = timeit.default_timer() - # Epoch 0, batch 0, loss 1.0109944343566895 - # Epoch 10, batch 0, loss 1.0146660804748535 - # Epoch 20, batch 0, loss 1.0119813680648804 - # Epoch 30, batch 0, loss 1.0105408430099487 - # Epoch 40, batch 0, loss 1.0077732801437378 - # Epoch 50, batch 0, loss 1.0050103664398193 - # Epoch 60, batch 0, loss 1.0033129453659058 - # Epoch 70, batch 0, loss 1.0014232397079468 - # Epoch 80, batch 0, loss 0.9977912306785583 - # Epoch 90, batch 0, loss 0.8274348974227905 - # Epoch 100, batch 0, loss 0.3368612825870514 - # Epoch 110, batch 0, loss 0.11323091387748718 - # Time taken: 17.591704960912466 +# Epoch 0, batch 0, loss 1.0109944343566895 +# Epoch 10, batch 0, loss 1.0146660804748535 +# Epoch 20, batch 0, loss 1.0119813680648804 +# Epoch 30, batch 0, loss 1.0105408430099487 +# Epoch 40, batch 0, loss 1.0077732801437378 +# Epoch 50, batch 0, loss 1.0050103664398193 +# Epoch 60, batch 0, loss 1.0033129453659058 +# Epoch 70, batch 0, loss 1.0014232397079468 +# Epoch 80, batch 0, loss 0.9977912306785583 +# Epoch 90, batch 0, loss 0.8274348974227905 +# Epoch 100, batch 0, loss 0.3368612825870514 +# Epoch 110, batch 0, loss 0.11323091387748718 +# Time taken: 17.591704960912466 for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -296,8 +276,7 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) - + print('Time taken: ', stop - start) def _test_knowledge_base_lookup_autocast(): K = 16 @@ -315,21 +294,14 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. 
+ y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device), - torch.randn(B, T, E, device=device), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -337,11 +309,12 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() + for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -350,9 +323,10 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) + print('Time taken: ', stop - start) -if __name__ == "__main__": + +if __name__ == '__main__': _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index 527c735eb..f726c2583 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,7 +79,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -147,7 +149,8 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -179,7 +182,11 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -195,12 +202,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -211,13 +218,19 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def 
forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -232,12 +245,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -277,7 +290,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -292,12 +309,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -636,7 +653,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -666,8 +685,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 3f21133a0..6293e081a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,23 +15,21 @@ # limitations under the License. -from typing import Optional, Tuple - import torch import torch.nn as nn from torch import Tensor +from typing import Tuple, Optional -def _activation_balancer_loss( - mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10, -): + +def _activation_balancer_loss(mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + eps: float = 1.0e-10): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -52,32 +50,28 @@ def _activation_balancer_loss( """ loss_parts = [] - x_mean = mean_pos - mean_neg - x_mean_abs = (mean_pos + mean_neg + eps).detach() - x_rel_mean = x_mean / x_mean_abs + x_mean = mean_positive - mean_negative + x_mean_abs = (mean_positive + mean_negative + eps).detach() + x_rel_mean= x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = -(1 - min_positive) + min_positive - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( - 1.0 / (2 * min_positive) - ) + x_rel_mean_floor = (-(1-min_positive) + min_positive) + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = -(1.0 - max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( - 1.0 / (1 - x_rel_mean_ceil) - ) + x_rel_mean_ceil = - (1.0-max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -88,53 +82,43 @@ def _activation_balancer_loss( # 100% violated. loss_parts.append(max_abs_loss) + # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - # num + num if min_positive != 0.0: - pass + + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -142,16 +126,11 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -184,30 +163,29 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) + def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() - ) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + self.eps.exp()) ** -0.5 return x * scales + + class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -229,26 +207,27 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
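As the ScaledLinear docstring above describes, the stored weight is kept at a small fixed standard deviation while a learned log-scale restores the usual magnitude. A rough numerical sketch of that initialisation, assuming std = 0.01 (as in this file) and an arbitrary fan_in of 512 purely for illustration:

# Sketch of the scaled parameterization: weight * weight_scale.exp()
# ends up with roughly 1/sqrt(fan_in) scale even though the stored
# weight has stddev ~= 0.01.
import math
import torch

fan_in, fan_out = 512, 256
std = 0.01
a = math.sqrt(3.0) * std                  # uniform(-a, a) has stddev ~= std
weight = torch.empty(fan_out, fan_in).uniform_(-a, a)

scale = fan_in ** -0.5                    # 1/sqrt(fan_in)
weight_scale = torch.tensor(scale / std).log()

effective = weight * weight_scale.exp()
print(float(weight.std()), float(effective.std()), scale)
# weight.std() ~= 0.01, effective.std() ~= 1/sqrt(512) ~= 0.044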
""" - - def __init__(self, *args, initial_scale: float = 1.0, **kwargs): + def __init__(self, *args, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -258,67 +237,56 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + def __init__(self, *args, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - _single(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, 
self.groups) + class ScaledConv2d(nn.Conv2d): @@ -329,58 +297,45 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - _pair(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) + + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -409,16 +364,12 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -428,15 +379,10 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwishFunction(torch.autograd.Function): @@ -454,7 +400,6 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
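A quick standalone check of the derivative identity stated in the DoubleSwishFunction docstring, namely that with s(x) = sigmoid(x - 1) and y = x * s(x), dy/dx equals y * (1 - s(x)) + s(x):

# Compare the docstring's closed-form derivative with autograd.
import torch

x = torch.linspace(-3.0, 3.0, 13, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s
y.sum().backward()

manual = y.detach() * (1.0 - s.detach()) + s.detach()
assert torch.allclose(x.grad, manual, atol=1e-6)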
""" - @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -466,17 +411,18 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) + + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -545,13 +491,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -560,40 +501,33 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -603,37 +537,24 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def 
extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = '{num_embeddings}, {embedding_dim}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) @@ -644,13 +565,8 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -660,22 +576,17 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) - def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -710,7 +621,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == "__main__": +if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a60d15c3b..2f6840166 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,7 +78,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -177,45 +179,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -555,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -727,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -827,7 +835,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 1df1650f3..2d5724d30 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,24 +123,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
" - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -208,7 +204,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -275,7 +272,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -290,7 +289,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +338,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -409,7 +415,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -442,7 +450,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -485,7 +494,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -517,12 +528,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -545,12 +557,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, 
iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +591,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 008f40fb1..318cd5094 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,9 +272,13 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) + emformer_out, emformer_out_lens, states = self.model.infer( + x, x_lens, states + ) - if x.size(1) != (self.model.segment_length + self.model.right_context_length): + if x.size(1) != ( + self.model.segment_length + self.model.right_context_length + ): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 81afb523d..2375f5001 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -173,12 +170,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -201,12 +199,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -234,7 +233,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index ed6848879..2f019bcdb 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,7 +122,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 6b30d3be8..fed814f19 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,45 +209,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -569,7 +566,11 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -598,7 +599,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -779,7 +782,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -903,7 +908,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 830b37cfb..7af9cc3d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,7 +670,9 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -686,7 +688,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
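For the length_norm option of get_most_probable, dividing each hypothesis score by the number of emitted tokens keeps the search from always preferring short hypotheses. A toy illustration with a stand-in Hyp class (not the actual Hypothesis type) and invented scores:

# Length normalisation can change which hypothesis wins.
from dataclasses import dataclass
from typing import List

@dataclass
class Hyp:
    ys: List[int]
    log_prob: float

hyps = [
    Hyp(ys=[0, 0, 7], log_prob=-4.0),            # short, higher total score
    Hyp(ys=[0, 0, 7, 3, 9, 2], log_prob=-6.0),   # longer, better per token
]

best_raw = max(hyps, key=lambda h: h.log_prob)
best_norm = max(hyps, key=lambda h: h.log_prob / len(h.ys))
print(best_raw.ys)    # [0, 0, 7]
print(best_norm.ys)   # [0, 0, 7, 3, 9, 2]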
""" if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -888,7 +892,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1082,7 +1088,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -1122,7 +1130,9 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 03ad45f49..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,7 +128,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -167,11 +171,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -267,7 +269,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -380,7 +383,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if ( @@ -445,7 +450,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -576,7 +584,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -609,7 +619,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -667,7 +678,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -705,7 +718,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -743,7 +757,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index e522943c0..386248554 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,7 +75,9 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -89,11 +91,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
- self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -122,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..f4355e8a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,7 +92,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 64708e524..b5a151878 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -194,7 +192,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2cca7fa27..73b651b3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,7 +130,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index a42b63b9c..eb95827af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -224,9 +221,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -294,7 +292,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,7 +381,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index 9e09200a1..dcf6dc42f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,10 +166,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index a50b4d4f0..d2cae4f9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -266,7 +269,9 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -286,7 +291,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -342,7 +349,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -413,7 +422,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -449,7 +460,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -521,7 +533,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index dd0331a60..399b11a29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -565,7 +562,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -585,7 +584,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -776,7 +777,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -894,7 +897,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -952,7 +956,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..b7c2010f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,7 +775,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -791,7 +793,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -986,7 +990,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -998,7 +1004,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1668,7 +1676,9 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1794,7 +1804,9 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1804,7 +1816,9 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + 
ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1827,7 +1841,9 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1979,7 +1995,9 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2014,7 +2032,10 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2046,7 +2067,9 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 34ff0d7e2..bc273d33b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. 
""" - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -776,7 +785,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -800,7 +811,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1114,9 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1185,25 +1198,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. 
# convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1243,15 +1264,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1293,17 +1322,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1322,9 +1355,13 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1461,12 +1498,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 32cd53be3..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,7 +132,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -173,11 +177,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -273,7 +275,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -394,7 +397,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -460,7 +465,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -506,7 +514,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -596,7 +608,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -629,7 +643,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -685,7 +700,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -723,7 +740,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No 
checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -761,7 +779,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..ba91302ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,11 +107,15 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( + -1 + ) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 90367bd03..f1a8ea589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -170,7 +173,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -218,7 +222,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..6a9d08033 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,7 +60,9 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..417c391d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -150,7 +152,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..041a81f45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] 
@@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -170,14 +180,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -285,9 +299,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 58de6875f..f52cb22ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -225,9 +222,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -295,7 +293,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f671e97b1..8c572a9ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,7 +89,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -135,7 +137,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -227,7 +229,8 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -279,12 +282,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -298,7 +301,9 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): @@ -326,12 +331,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -395,12 +400,12 @@ class 
ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -471,7 +476,9 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + self.grad_filter = GradientFilter( + batch_dim=1, threshold=grad_norm_threshold + ) self._reset_parameters( initial_speed @@ -479,8 +486,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 + a = (3 ** 0.5) * std + scale = self.hidden_size ** -0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -552,11 +559,15 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + flat_weights.append( + self._flat_weights[idx] * self._scales[idx].exp() + ) self._flatten_parameters(flat_weights) return flat_weights - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -904,7 +915,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -934,8 +947,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -988,18 +1001,17 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median *" - " threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), + (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), + (x.grad ** 2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index e6e0fb1c8..9bcd2f9f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,7 +153,9 @@ def modified_beam_search( 
index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -170,10 +172,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 0139863a1..d76a03946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -268,7 +271,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -288,7 +293,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -344,7 +351,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -416,7 +425,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -451,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -524,7 +536,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 623bdd51a..1947834bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,7 +96,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -208,7 +210,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to " + "be changed.", ) parser.add_argument( @@ -231,45 +234,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -634,7 +634,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -647,9 +649,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -660,7 +667,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -828,7 +837,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -952,7 +963,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 5e81aef07..1df7f9ee5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,7 +27,10 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + PrecomputedFeatures, +) from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -41,69 +44,59 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." 
- ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets).", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -124,22 +117,18 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -153,11 +142,9 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet" - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. 
Used only in dev/test CutSet", ) def train_dataloaders( @@ -180,7 +167,9 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -189,7 +178,9 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -259,7 +250,9 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 66c8e30ba..5784a78ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,7 +79,11 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -116,11 +120,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -190,7 +192,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -277,7 +280,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -307,7 +312,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -351,11 +359,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -428,7 +446,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -461,7 +481,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,7 +532,9 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -544,7 +567,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d90497e26..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,7 +120,11 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from 
icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -163,11 +167,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -263,7 +265,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -475,7 +478,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -545,7 +550,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -638,11 +646,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -672,7 +690,12 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -756,7 +779,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -789,7 +814,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -913,7 +939,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix 
+= ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -953,7 +981,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1003,10 +1032,15 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": + if ( + params.decoding_method + == "fast_beam_search_with_nbest_rnn_rescoring" + ): rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1031,7 +1065,9 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index dcf65e937..47217ba05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,7 +128,11 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -160,11 +164,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -233,7 +235,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -506,9 +509,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -609,7 +616,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -707,7 +715,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 598434f54..36f32c6b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,14 +52,18 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = [ + (int(pattern.search(f).group(1)), f) for f in filenames + ] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + return lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 108915389..162f8c7db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,12 +104,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -144,9 +142,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -331,7 +330,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..7852f84e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,7 +203,9 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -212,10 +214,16 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -224,7 +232,11 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) @torch.no_grad() @@ -276,7 +288,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 11597aa49..ea5d4e674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,12 +102,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -142,9 +140,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -192,7 +191,11 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, )[0] blank_id = 0 # hard-code to 0 @@ -379,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 849d6cf4e..19b636a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -234,9 +231,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -304,7 +302,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,7 +391,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 85d87f8f2..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,7 +234,9 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + scaled_weight = ( + scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + ) lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -249,10 +251,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 41a712498..10bb44e00 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,7 +52,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -91,11 +95,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -161,7 +163,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -269,7 +272,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -289,7 +294,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -345,7 +352,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -417,7 +426,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -450,7 +461,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -523,7 +535,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,7 +90,9 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) os.remove(filename) @@ -145,7 +147,9 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -193,7 +197,9 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) torch.onnx.export( encoder_layer, 
@@ -230,7 +236,9 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -314,7 +322,9 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -369,7 +379,9 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 6724343dd..44e96644a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,7 +92,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -211,7 +214,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -234,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -671,7 +672,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -684,9 +687,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -697,7 +705,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -909,7 +919,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -955,7 +967,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1096,7 +1109,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 69cfcd298..4f043e5a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,24 +197,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -310,7 +306,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -430,7 +427,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if ( params.decoding_method == "fast_beam_search" @@ -486,7 +485,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -564,7 +566,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. 
Args: @@ -639,7 +643,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -648,7 +654,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -686,7 +694,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -713,7 +722,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -762,7 +773,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -799,12 +812,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -827,12 +841,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -860,7 +875,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -887,7 +902,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index bd5801a78..ce7518ceb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -186,12 +183,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -214,12 +212,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -247,7 +246,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -283,7 +282,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index a28e52c78..7af9ea9b8 
100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 76785a845..cf32e565b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -237,45 +239,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -622,7 +621,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -662,7 +665,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -675,9 +680,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -688,7 +698,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -867,7 +879,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -999,7 +1013,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8499651d7..427b06294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -793,7 +802,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -809,7 +820,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -835,7 +848,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. 
Args: @@ -1103,9 +1118,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1174,25 +1189,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,15 +1253,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1279,17 +1310,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * 
num_heads, tgt_len, src_len ) @@ -1301,9 +1336,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1442,12 +1481,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: @@ -1623,7 +1666,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -1720,14 +1765,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( @@ -1757,8 +1804,7 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," - f" stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index f462cc42f..22bcdd88e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,24 +179,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -307,7 +303,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -480,7 +477,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -546,7 +545,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -694,7 +696,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -727,7 +731,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -782,7 +787,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -821,12 +828,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -849,12 +857,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < 
params.avg + 1: raise ValueError( @@ -882,7 +891,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -928,7 +937,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index a739c17bc..b2e5b430e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -281,7 +280,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index e2da0da4c..1e100fcbd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 59a0e8fa2..6fee9483e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 75696d61b..179d9372e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -246,7 +248,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -269,45 +272,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -686,7 +690,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -699,9 +705,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -712,7 +723,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
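
The inline conditional for pruned_loss_scale above encodes a three-stage warm-up: the pruned loss is ignored before the model is warmed up, lightly weighted in a transition band, and fully weighted afterwards; the total loss is then params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss. A minimal, self-contained restatement of that schedule (same thresholds as in the hunk):

    def pruned_loss_scale(warmup: float) -> float:
        # Mirrors the inline conditional in compute_loss() above.
        if warmup < 1.0:
            return 0.0   # not warmed up yet: train on the simple loss only
        if 1.0 < warmup < 2.0:
            return 0.1   # transition band: keep the pruned loss small
        return 1.0       # fully warmed up

    assert pruned_loss_scale(0.5) == 0.0
    assert pruned_loss_scale(1.5) == 0.1
    assert pruned_loss_scale(2.5) == 1.0
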
- info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -895,7 +908,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1008,7 +1023,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1030,7 +1045,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 40ad61fd4..53788b3f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,7 +90,10 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers + assert ( + middle_output_layer >= 0 + and middle_output_layer < num_encoder_layers + ) output_layers.append(middle_output_layer) # The last layer is always needed. 
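
The output_layers logic at the end of this hunk makes the stateless6 Conformer return a chosen middle layer's output alongside the final layer (the middle layer feeds the codebook-index distillation loss in model.py further below). A small illustrative sketch of that pattern, using a plain stack of linear layers rather than the actual encoder:

    import torch
    import torch.nn as nn

    class StackWithTaps(nn.Module):
        """Return the outputs of selected layers of a layer stack."""

        def __init__(self, num_layers: int, dim: int, taps):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.Linear(dim, dim) for _ in range(num_layers)]
            )
            self.taps = sorted(set(taps))  # e.g. [middle_layer, num_layers - 1]

        def forward(self, x):
            outputs = []
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i in self.taps:
                    outputs.append(x)
            return outputs

    model = StackWithTaps(num_layers=6, dim=8, taps=[2, 5])
    ys = model(torch.randn(4, 8))
    assert len(ys) == 2  # middle-layer output plus the final layer
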
@@ -175,7 +178,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -357,7 +362,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -372,7 +379,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -647,9 +656,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -718,25 +727,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) key_padding_mask = key_padding_mask.to(torch.bool) @@ -773,7 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -781,9 +800,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -817,9 +840,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -842,7 +869,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 600aa9b39..74df04006 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,24 +120,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -212,7 +208,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -270,7 +267,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + layer_results, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) encoder_out = layer_results[-1] hyps = [] @@ -286,7 +285,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +334,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -405,7 +411,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -438,7 +446,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -481,7 +490,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -513,12 +524,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -541,12 +553,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -574,7 +587,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 86cf34877..21409287c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,10 +21,9 @@ import os from pathlib import Path import torch +from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from vq_utils import CodebookIndexExtractor - from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index b8440f90a..49b557814 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch + from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -98,7 +99,9 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) 
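
Several hunks above adjust the wrapping of find_checkpoints()/average_checkpoints() calls. For orientation only, the sketch below shows one simplified way such checkpoint averaging can be done, an element-wise mean of the saved "model" state dicts; it is an assumption-laden illustration, not icefall's actual average_checkpoints() implementation:

    import torch

    def average_state_dicts(filenames, device="cpu"):
        """Element-wise mean of the 'model' state_dicts stored in `filenames`."""
        n = len(filenames)
        avg = None
        for f in filenames:
            state = torch.load(f, map_location=device)["model"]
            if avg is None:
                avg = {k: v.detach().clone().float() for k, v in state.items()}
            else:
                for k, v in state.items():
                    avg[k] += v.float()
        return {k: v / n for k, v in avg.items()}
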
return results @@ -121,7 +124,9 @@ def save_results( ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -150,7 +155,9 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + params.res_dir = ( + params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + ) setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -183,7 +190,9 @@ def main(): params=params, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 4f9417c9f..55ce7b00d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,7 +22,11 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import checkpoint_utils, tasks, utils +from fairseq import ( + checkpoint_utils, + tasks, + utils, +) from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -47,7 +51,9 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") + model_path = Path(params.hubert_model_dir) / ( + params.teacher_model_id + ".pt" + ) task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -145,7 +151,9 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( + [-1, 1] + ) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -155,7 +163,9 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) + padding_mask = self.w2v_model.forward_padding_mask( + features, padding_mask + ) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -202,7 +212,9 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] + hyps = [ + self.processor.string(tok[tok != blank].int().cpu()) for tok in toks + ] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,7 +69,9 @@ class Transducer(nn.Module): self.decoder = decoder 
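
The hubert_xlarge.py hunk above implements CTC greedy search: per-frame arg-max, collapse of consecutive repeats, and removal of the blank symbol before mapping tokens to letters. A self-contained version of just that step for a single utterance, with tiny made-up posteriors and blank id 0:

    import torch

    def ctc_greedy(log_probs: torch.Tensor, blank: int = 0):
        # log_probs: (num_frames, vocab_size)
        toks = log_probs.argmax(dim=-1)
        toks = torch.unique_consecutive(toks)   # collapse repeated tokens
        return toks[toks != blank].tolist()     # drop the blank symbol

    posteriors = torch.tensor(
        [[0.1, 0.8, 0.1],   # frame 0 -> token 1
         [0.1, 0.8, 0.1],   # frame 1 -> token 1 (repeat, collapsed)
         [0.8, 0.1, 0.1],   # frame 2 -> blank
         [0.1, 0.1, 0.8]]   # frame 3 -> token 2
    )
    assert ctc_greedy(posteriors.log()) == [1, 2]
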
self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -178,7 +180,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -233,7 +237,9 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): + def concat_successive_codebook_indexes( + middle_layer_output, codebook_indexes + ): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index be54ff0ce..f717d85fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -201,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -570,7 +569,9 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. 
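
concat_successive_codebook_indexes() in the model.py hunk above handles the teacher/student frame-rate mismatch noted in its comments: HuBERT emits codebook indexes at 50 frames per second while the encoder output is at 25, so two successive teacher frames are paired per encoder frame (with -100 marking positions to ignore in the loss). A small illustrative sketch of the pairing, assuming an even number of frames; the real code also handles padding and length mismatches:

    import torch

    def concat_successive_pairs(codebook_indexes: torch.Tensor) -> torch.Tensor:
        # codebook_indexes: (N, T, C) with T even -> (N, T // 2, 2 * C)
        N, T, C = codebook_indexes.shape
        assert T % 2 == 0, "pad or trim to an even number of frames first"
        return codebook_indexes.reshape(N, T // 2, 2 * C)

    x = torch.arange(12).reshape(1, 6, 2)   # 6 teacher frames, 2 codebooks
    y = concat_successive_pairs(x)
    assert y.shape == (1, 3, 4)
    assert y[0, 0].tolist() == [0, 1, 2, 3]  # frames 0 and 1 paired together
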
- cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = [ + c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts + ] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -603,7 +604,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,7 +655,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -663,9 +670,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -678,7 +690,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -859,7 +873,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -991,7 +1007,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 40f97f662..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,9 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -206,7 +208,9 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to(dtype=torch.float) + yield data[start:end, :].to(self.params.device).to( + dtype=torch.float + ) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -223,11 +227,10 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - split_cmd = ( - "lhotse split" - f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" + ori_manifest = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" ) + split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -237,13 +240,16 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -263,7 +269,8 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -323,7 +330,9 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index fa8144935..06c5863f1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,24 +164,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. 
Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -276,7 +272,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -396,7 +393,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -455,7 +454,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -586,7 +588,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -619,7 +623,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -674,7 +679,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -711,12 +718,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -739,12 +747,13 @@ def main(): 
model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -772,7 +781,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -799,7 +808,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..712dc8ce1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim // 4, # group size == 4 + groups=decoder_dim//4, # group size == 4 bias=False, ) @@ -91,7 +91,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 43ac658e5..5744ea3ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -218,12 +215,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -246,12 +244,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -279,7 +278,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -317,7 +316,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index c94a34d58..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..7d8de5afe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..53cde6c6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,15 +15,14 @@ # limitations under the License. -import random - import k2 import torch import torch.nn as nn +import random from encoder_interface import EncoderInterface -from scaling import penalize_abs_values_gt from icefall.utils import add_sos +from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -66,8 +65,7 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, + encoder_dim, vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -135,16 +133,18 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 460ac2c3e..bb8b0a0e3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib -import logging -import random from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch +from typing import List, Optional, Union, Tuple, List from lhotse.utils import fix_random_seed +import torch from scaling import ActivationBalancer +import random from torch import Tensor from torch.optim import Optimizer +import logging +import contextlib + class BatchedOptimizer(Optimizer): @@ -37,10 +37,11 @@ class BatchedOptimizer(Optimizer): Args: params: """ - def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager def batched_params(self, param_group): """ @@ -72,9 +73,7 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -83,7 +82,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] + batches = [ batches[key] for key in sorted(batches.keys()) ] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -95,78 +94,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) + class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. 
By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. 
+ clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -185,6 +183,7 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -207,9 +206,7 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -228,9 +225,13 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) + return loss - def _init_state(self, group: dict, p: Tensor, state: dict): + def _init_state(self, + group: dict, + p: Tensor, + state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -246,7 +247,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {"device": p.device, "dtype": p.dtype} + kwargs = {'device':p.device, 'dtype':p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -254,30 +255,36 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() + if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, + **kwargs) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] - ) -> float: + def _get_clipping_scale(self, + group: dict, + pairs: List[Tuple[Tensor, dict]]) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. 
@@ -307,67 +314,57 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + tot_sumsq += ((grad * state["param_rms"])**2).sum() tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) + if not "model_norms" in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, + device=p.device) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + sorted_norms = first_state["model_norms"].sort()[0].to('cpu') quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, - (clipping_update_period // 4) * n, - ) + index = min(clipping_update_period - 1, + (clipping_update_period // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) + percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state else 0.0) first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) + quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) + logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) + except: + logging.info("Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?") return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}," - f" model_norm_threshold={model_norm_threshold}" - ) + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") return ans - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): + + def _step_one_batch(self, + group: dict, + p: Tensor, + state: dict, + clipping_scale: float): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. 
@@ -394,18 +391,17 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) + dim=list(range(1, p.ndim)), keepdim=True) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt()) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) + if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -415,21 +411,24 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + def _size_update(self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -444,28 +443,25 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2**size_update_period + beta2_corr = beta2 ** size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) + (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` + alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step + bias_correction2 = 1 - beta2_corr ** size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) + scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms + is_too_small = (param_rms < param_min_rms) + is_too_large = (param_rms > param_max_rms) # when the param gets too small, just don't shrink it any further. 
scale_step.masked_fill_(is_too_small, 0.0) @@ -473,9 +469,13 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) + delta.add_(p * scale_step, alpha=(1-beta1)) - def _step(self, group: dict, p: Tensor, state: dict): + + def _step(self, + group: dict, + p: Tensor, + state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,7 +496,8 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=(1-beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -508,13 +509,17 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - def _step_scalar(self, group: dict, p: Tensor, state: dict): + + def _step_scalar(self, + group: dict, + p: Tensor, + state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -526,7 +531,8 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=1-beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. 
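# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not the patch's code): the
# scalar-parameter branch of ScaledAdam shown above is essentially Adam with
# only the second moment bias-corrected and a hard clamp on the parameter
# value.  Defaults mirror ScaledAdam.__init__ (lr=3e-02, betas=(0.9, 0.98),
# eps=1e-08, scalar_max=10.0); the state dict layout is assumed.
import torch

def scalar_adam_step(p, grad, state, lr=3e-02, betas=(0.9, 0.98),
                     eps=1.0e-08, scalar_max=10.0):
    beta1, beta2 = betas
    exp_avg_sq = state["exp_avg_sq"]              # running average of grad**2
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # only the second moment is bias-corrected; the slow start acts as a warmup
    bias_correction2 = 1 - beta2 ** (state["step"] + 1)
    denom = (exp_avg_sq / bias_correction2).sqrt() + eps
    delta = state["delta"]                        # momentum buffer
    delta.add_(grad / denom, alpha=-lr * (1 - beta1))
    p.clamp_(min=-scalar_max, max=scalar_max)
    p.add_(delta)
    state["step"] += 1

p = torch.zeros(3)
state = {"step": 0, "exp_avg_sq": torch.zeros(3), "delta": torch.zeros(3)}
scalar_adam_step(p, torch.ones(3), state)
# ---------------------------------------------------------------------------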
@@ -534,11 +540,12 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + delta.add_(grad / denom, alpha=-lr*(1-beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -548,14 +555,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["base_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -669,15 +680,13 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -736,14 +745,13 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam: A Method for Stochastic Optimization: + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. 
_On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__( self, params, @@ -758,11 +766,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -798,7 +812,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -825,7 +841,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -836,31 +852,30 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) + step = (exp_avg/denom) * step_size + logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") + return loss def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear - E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") + #device = torch.device('cuda') + device = torch.device('cpu') dtype = torch.float32 fix_random_seed(42) @@ -874,93 +889,79 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential(Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] + train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: optim = 
Eve(m.parameters(), lr=0.003) + elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: + #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - # if epoch == 130: + #if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - for n, (x, y) in enumerate(train_pairs): + + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" - f" {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - # diagnostic.print_diagnostics() + #diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) + #logging.info("state dict = ", scheduler.state_dict()) + #logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) + s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) logging.info(s) import sys - if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 8b4d88871..7fe1e681a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4040065e1..50cedba56 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections -import logging -import random -from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union +from functools import reduce +import logging +import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,24 +32,27 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) 
if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -62,22 +65,14 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_scale_factor(x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -88,76 +83,71 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) return below_threshold - above_threshold - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_sign_factor(x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), + dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) + factor1 = ((min_positive - proportion_positive) * + (gain_factor / min_positive)).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) + factor2 = ((proportion_positive - max_positive) * + (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor + class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" - @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -165,24 +155,18 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -195,32 +179,30 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors + is_same, = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): +def random_clamp(x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = x_abs < min_abs + is_too_small = (x_abs < min_abs) # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -233,7 +215,6 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" - @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -242,37 +223,35 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) + return random_cast_to_half(ans_grad.to(torch.float32), + min_abs=ctx.min_abs), None else: return ans_grad, None - class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - - def __init__(self, min_abs: float = 5.0e-06): + def __init__(self, + min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, x: Tensor): + def forward(self, + x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) + class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ - @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -288,7 +267,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors + ans, = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -297,7 +276,9 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None -def softmax(x: Tensor, dim: int): + +def softmax(x: Tensor, + dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -307,18 +288,20 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) return x + @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -328,20 +311,15 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x**2).mean() + x_var = (x ** 2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() + x_residual_var = (x_residual ** 2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -407,12 +385,15 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. 
eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -431,11 +412,16 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -454,10 +440,13 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -497,19 +486,18 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -527,7 +515,9 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -545,35 +535,26 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) + sign_factor = _compute_sign_factor(x, self.channel_dim, + self.min_positive, self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor) else: sign_factor = None - scale_factor = _compute_scale_factor( - x, - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) + + scale_factor = _compute_scale_factor(x, self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor) return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, + x, scale_factor, sign_factor, self.channel_dim, ) else: return _no_op(x) @@ -613,12 +594,13 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] + x = x[:, ::dim+1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, num_groups: int): +def _whitening_metric(x: Tensor, + num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -648,21 +630,19 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
- metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float, - ) -> Tensor: + def forward(ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -670,8 +650,9 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -680,29 +661,25 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}," - f" num_channels={x_orig.shape[-1]}," - f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) + logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float,float]], + grad_scale: float): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -737,7 +714,8 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, x: Tensor) -> Tensor: + def forward(self, + x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -757,21 +735,19 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, "min_prob") and random.random() < 0.25: + if hasattr(self, 'min_prob') and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) class WithLoss(torch.autograd.Function): @@ -779,14 +755,11 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x - @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device) def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -795,7 +768,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if (torch.jit.is_scripting()): return x else: # a no-op function that will have a node in the autograd graph, @@ -810,7 +783,6 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) - class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -831,14 +803,13 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -854,7 +825,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) + self.register_buffer('max_eig_direction', direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -862,12 +833,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - ): + if (torch.jit.is_scripting() or + self.max_var_per_eig <= 0 or + random.random() > self.cur_prob): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -877,9 +848,7 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) + new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -892,10 +861,7 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}," - f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) + logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") if variance_proportion >= self.max_var_per_eig: # The constraint is active. 
Note, we should quite rarely @@ -903,16 +869,17 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) + return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, + self.channel_dim, self.scale) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - def _set_direction(self, direction: Tensor): + + def _set_direction(self, + direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -922,39 +889,40 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) + logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}") - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() + def _find_direction_coeffs(self, + x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. 
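The two lines that follow amount to one step of power iteration (up to scaling): project every frame onto the previous direction, then refit the direction to those projections. The same estimate written standalone, together with the variance-proportion quantity that forward() compares against max_var_per_eig (the toy data and variable names are ours; the epsilon handling may differ slightly from forward()):

    import torch

    def one_step_power_iteration(x: torch.Tensor, prev_direction: torch.Tensor):
        # coeffs: projection of each frame onto the previous direction estimate
        # (only up to a positive constant factor, which is all that is needed)
        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
        return cur_direction, coeffs

    num_frames, num_channels = 2000, 128
    x = torch.randn(num_frames, num_channels)
    x[:, 0] *= 10.0                      # make channel 0 carry most of the variance
    x = x - x.mean(dim=0)

    direction = torch.randn(num_channels)
    direction = direction / direction.norm()
    for _ in range(4):                   # a few refinements home in on the top eigen-direction
        direction, coeffs = one_step_power_iteration(x, direction)

    x_var = (x ** 2).mean()
    x_residual_var = ((x - coeffs * direction) ** 2).mean()
    variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
    print(variance_proportion)           # ~0.44 here: the share of variance in the top direction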
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) return cur_direction, coeffs + + class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -982,7 +950,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = y * (1 - s) + s + deriv = (y * (1 - s) + s) # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -991,9 +959,7 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1006,12 +972,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors + d, = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) class DoubleSwish(torch.nn.Module): @@ -1024,6 +990,7 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) + def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1035,9 +1002,11 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale + m = MaxEig(num_channels, + 1, # channel_dim + 0.5, # max_var_per_eig + scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1062,9 +1031,11 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1078,6 +1049,7 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) + def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1105,7 +1077,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1137,8 +1111,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1150,27 +1124,30 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = (1.2 - (-0.043637)) / 255.0 + tol = ((1.2-(-0.043637))/255.0) torch.autograd.gradcheck(m, x, atol=tol) + # for self-test. 
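The uint8 trick used by DoubleSwishFunction above can be checked in isolation: the derivative of x * sigmoid(x - 1) lies roughly in [-0.0437, 1.2], so it is affinely mapped onto [0, 255], stored as uint8 with randomized rounding, and mapped back in backward(). A standalone round-trip check (pure torch; the clamp is added here for safety and is not in the original):

    import torch

    x = torch.randn(10000, dtype=torch.float32)
    s = torch.sigmoid(x - 1.0)
    y = x * s                        # double_swish(x)
    deriv = y * (1.0 - s) + s        # d/dx [x * sigmoid(x - 1)]

    floor, ceil = -0.043637, 1.2     # bounds on the derivative, as in the code above
    d_scaled = ((deriv - floor) * (255.0 / (ceil - floor))
                + torch.rand_like(deriv)).clamp(0.0, 255.0)
    d_int = d_scaled.to(torch.uint8)                            # what gets saved for backward
    d_back = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor

    # error is about one quantization step, i.e. (1.2 - (-0.0436)) / 255 ~= 0.005
    print((d_back - deriv).abs().max())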
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) + def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() + a.softmax(dim=1)[:,0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() + softmax(b, dim=1)[:,0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 46e775285..8d357b15f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,7 +26,11 @@ from typing import List import torch import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten +from scaling import ( + ActivationBalancer, + BasicNorm, + Whiten, +) class NonScaledNorm(nn.Module): @@ -71,10 +75,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 7f9526104..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,7 +84,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -122,10 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -140,11 +139,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -272,45 +269,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -652,7 +646,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -699,7 +697,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,7 +870,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
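Concretely, the workaround described in this comment (and implemented just below) doubles the scale whenever it drifts too low: immediately while it is under 1.0, and once every 400 batches while it is under 8.0; the new value is then pushed back with scaler.update(). A pure-function restatement of the rule (the function name is ours):

    def maybe_grow_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
        # Mirror of the heuristic used in train_one_epoch():
        #  * below 1.0 -> double right away,
        #  * below 8.0 -> double once every 400 batches,
        #  * otherwise -> leave the GradScaler to its own dynamics.
        if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
            return cur_grad_scale * 2.0
        return cur_grad_scale

    assert maybe_grow_grad_scale(0.5, batch_idx=3) == 1.0
    assert maybe_grow_grad_scale(4.0, batch_idx=400) == 8.0
    assert maybe_grow_grad_scale(4.0, batch_idx=401) == 4.0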
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -888,7 +890,11 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -899,7 +905,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -907,7 +915,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -919,8 +930,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -999,7 +1009,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1017,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1042,7 +1054,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1216,8 +1229,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index fcd9858cd..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,35 +16,32 @@ # limitations under the License. 
import copy -import itertools -import logging import math -import random import warnings +import itertools from typing import List, Optional, Tuple, Union - +import logging import torch +import random from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) from scaling import ( ActivationBalancer, BasicNorm, - DoubleSwish, - Identity, MaxEig, + DoubleSwish, ScaledConv1d, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. Whiten, + Identity, _diag, - penalize_abs_values_gt, random_clamp, + penalize_abs_values_gt, softmax, ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask +from icefall.dist import get_rank class Zipformer(EncoderInterface): @@ -92,7 +89,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u, d in zip(encoder_unmasked_dims, encoder_dims): + for u,d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -100,9 +97,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) + self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], + dropout=dropout) + # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -126,13 +123,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -142,11 +139,10 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample( - encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor, - ) + self.downsample_output = AttentionDownsample(encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor) + def _get_layer_skip_dropout_prob(self): if not self.training: @@ -170,33 +166,27 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: + if i <= 1 or z[i-1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i - 2, -1, -1): + for j in range(i-2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has" - f" downsampling_factor={z[i]}, we will combine the outputs" - f" of layers {j} and {i-1}, with" - f" downsampling_factors={z[j]} and {z[i-1]}." 
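The pairing rule being logged here is compact enough to state on its own: stack i is combined with an earlier stack only when its downsampling factor is smaller than that of stack i-1, and the partner is the most recent stack j whose factor does not exceed stack i's, falling back to j = 0. A standalone sketch of just that selection (the factor list is a made-up example, not the recipe's defaults):

    def choose_skip_layers(downsampling_factors):
        # For each encoder stack, return the index of the earlier stack whose
        # output it is combined with, or None if no combination is done.
        skip_layers = []
        z = downsampling_factors
        for i in range(len(z)):
            if i <= 1 or z[i - 1] <= z[i]:
                skip_layers.append(None)
            else:
                for j in range(i - 2, -1, -1):
                    if z[j] <= z[i] or j == 0:
                        skip_layers.append(j)
                        break
        return skip_layers

    print(choose_skip_layers([1, 2, 4, 8, 4, 2, 1]))
    # -> [None, None, None, None, 2, 1, 0]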
- ) + logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) + skip_modules.append(SimpleCombiner(self.encoder_dims[j], + self.encoder_dims[i-1], + min_weight=(0.0,0.25))) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks(self, x: torch.Tensor) -> List[float]: + def get_feature_masks( + self, + x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -216,56 +206,46 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders + return [ 1.0 ] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) + + assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = num_frames0 + max_downsampling_factor - 1 + num_frames_max = (num_frames0 + max_downsampling_factor - 1) + feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) + frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds + upsample_factor = (max_downsampling_factor // ds) - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) + frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, + batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1)) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + dtype=x.dtype, device=x.device) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks + def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, + self, x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -285,19 +265,13 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), ( - x.shape, - lengths, - lengths.max(), - ) + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): + for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ 
-306,11 +280,9 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - ) + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) outputs.append(x) x = self.downsample_output(x) @@ -340,16 +312,15 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -359,24 +330,29 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, + d_model, attention_dim, nhead, pos_dim, dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -384,18 +360,14 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
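Like Whiten, the balancer instantiated just below is an identity in the forward pass; it only adjusts gradients, and only with a probability that decays with the batch count (see the prob schedule in ActivationBalancer.forward earlier in this file). A usage sketch, assuming this directory's scaling.py is importable; the channel count is illustrative, while the positive-fraction and max_abs values mirror the layer below:

    import torch
    from scaling import ActivationBalancer  # defined in this directory's scaling.py

    m = ActivationBalancer(128, channel_dim=-1,
                           min_positive=0.45, max_positive=0.55, max_abs=6.0)
    x = torch.randn(20, 4, 128)
    x.requires_grad = True
    y = m(x)             # the forward pass returns x unchanged
    y.sum().backward()   # with some probability, grads are nudged so that each channel's
                         # fraction of positive values stays near [0.45, 0.55] and
                         # magnitudes stay below max_abs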
self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -410,9 +382,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) + clamp_min = (initial_clamp_min - + (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -427,9 +398,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) + return (initial_dropout_rate - + (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) def forward( self, @@ -538,14 +508,13 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -559,7 +528,8 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -568,13 +538,15 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + + delta = (1. 
/ num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin + def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -607,14 +579,12 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) + return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -634,13 +604,11 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}," - f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," - f" num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") return ans + def forward( self, src: Tensor, @@ -671,6 +639,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src + if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -701,31 +670,28 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - - def __init__( - self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int, - ): + def __init__(self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) + self.out_combiner = SimpleCombiner(input_dim, + output_dim, + min_weight=(0.0, 0.25)) - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + + def forward(self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. 
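The two random generators in get_layers_to_drop() above implement a deliberate split: how many layers to drop is drawn from a generator seeded identically on every DDP worker (so all workers drop the same count and stay in step), while which layers are dropped is drawn per worker. A simplified standalone sketch of that idea; the stochastic rounding and the weighted draw below are our own illustration, not a copy of the exact selection loop:

    import random

    def pick_layers_to_drop(layerdrop_probs, batch_count, module_seed, worker_seed):
        shared_rng = random.Random(batch_count + module_seed)   # same on every worker
        independent_rng = random.Random(worker_seed)            # differs per worker
        tot = sum(layerdrop_probs)
        # stochastic rounding of the expected number of dropped layers
        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
        ans = set()
        while len(ans) < num_to_drop:
            i = independent_rng.choices(range(len(layerdrop_probs)),
                                        weights=layerdrop_probs)[0]
            ans.add(i)
        return ans

    probs = [0.05, 0.05, 0.3, 0.3]   # made-up per-layer probabilities
    print(pick_layers_to_drop(probs, batch_count=1000, module_seed=7, worker_seed=0))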
@@ -752,43 +718,42 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds, ::ds] + mask = mask[::ds,::ds] src = self.encoder( - src, - feature_mask=feature_mask, - mask=mask, - src_key_padding_mask=mask, + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] + src = src[:src_orig.shape[0]] return self.out_combiner(src_orig, src) - class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) else: self.extra_proj = None self.downsample = downsample - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -802,14 +767,16 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -828,12 +795,14 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - - def __init__(self, num_channels: int, upsample: int): + def __init__(self, + num_channels: int, + upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -846,7 +815,6 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src - class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -854,7 +822,6 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 - class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -864,14 +831,18 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. 
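AttentionDownsample above reduces the frame rate with a learned weighted average rather than plain striding: every group of `downsample` consecutive frames is scored against a single query vector and combined with softmax weights. A standalone sketch of the pooling step (omitting the right-padding, the score penalty and the extra output-channel projection shown above):

    import torch

    def attention_downsample_pool(src: torch.Tensor, query: torch.Tensor, ds: int) -> torch.Tensor:
        # src: (seq_len, batch, channels) with seq_len divisible by ds
        # query: (channels,) -- a learned parameter in the real module
        seq_len, batch, channels = src.shape
        src = src.reshape(seq_len // ds, ds, batch, channels)
        scores = (src * query).sum(dim=-1, keepdim=True)   # (seq_len//ds, ds, batch, 1)
        weights = scores.softmax(dim=1)                    # normalize within each group of ds frames
        return (src * weights).sum(dim=1)                  # (seq_len//ds, batch, channels)

    src = torch.randn(8, 2, 16)
    query = torch.randn(16) * 16 ** -0.5
    print(attention_downsample_pool(src, query, ds=2).shape)  # torch.Size([4, 2, 16])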
""" - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + def __init__(self, + dim1: int, + dim2: int, + min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + def forward(self, + src1: Tensor, + src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -882,14 +853,10 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) + if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): + weight1 = weight1.clamp(min=self.min_weight[0], + max=1.0-self.min_weight[1]) + src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -902,9 +869,12 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] + return src1 + src2 + + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -918,7 +888,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -933,7 +905,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -981,6 +955,7 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -1017,46 +992,34 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) + assert ( + self.head_dim * num_heads == attention_dim + ), (self.head_dim, num_heads, attention_dim) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + in_proj_dim = (2 * attention_dim + # query, key + attention_dim // 2 + # value + pos_dim * num_heads) # positional encoding query - self.in_proj = ScaledLinear( - embed_dim, - in_proj_dim, - bias=True, - initial_scale=self.head_dim**-0.25, - ) + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=self.head_dim**-0.25) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. 
- self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) + self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + initial_scale=0.05) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1068,16 +1031,14 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) + self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, + initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + def forward( self, @@ -1137,6 +1098,7 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights + def multi_head_attention_forward( self, x_proj: Tensor, @@ -1194,24 +1156,26 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert head_dim * num_heads == attention_dim, ( - f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," - f" {attention_dim}" - ) + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] + q = x_proj[...,0:attention_dim] + k = x_proj[...,attention_dim:2*attention_dim] value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] + p = x_proj[...,2*attention_dim+value_dim:] + k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. + if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1231,25 +1195,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." 
+ ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1258,6 +1230,7 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1266,10 +1239,13 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1280,16 +1256,13 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), + (pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2)-pos_weights.stride(3), + pos_weights.stride(3)), + storage_offset=pos_weights.stride(3) * (seq_len - 1)) + # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1302,9 +1275,10 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. 
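The as_strided() call a few lines above converts scores indexed by relative offset into scores indexed by absolute key position without copying: entry [b, h, i, j] of the result reads the relative entry at column (j - i) + (seq_len - 1). A small gather-based equivalent makes the index arithmetic explicit (an illustration of the mapping, not the code the patch uses):

    import torch

    seq_len = 4
    # rel[i, k] stores a score for relative offset k - (seq_len - 1); here just k itself
    rel = torch.arange(2 * seq_len - 1).repeat(seq_len, 1)

    cols = (torch.arange(seq_len).unsqueeze(0)     # j
            - torch.arange(seq_len).unsqueeze(1)   # i
            + (seq_len - 1))                       # column = (j - i) + (seq_len - 1)
    abs_scores = rel.gather(1, cols)
    print(abs_scores)
    # tensor([[3, 4, 5, 6],
    #         [2, 3, 4, 5],
    #         [1, 2, 3, 4],
    #         [0, 1, 2, 3]])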
- attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) + attn_output_weights = penalize_abs_values_gt(attn_output_weights, + limit=25.0, + penalty=1.0e-04) + # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1346,20 +1320,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [ - bsz * num_heads, - seq_len, - head_dim // 2, - ] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, + head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) return attn_output, attn_output_weights + def forward2( self, x: Tensor, @@ -1398,7 +1372,11 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1409,50 +1387,39 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}," - f" covar={attn_covar}, in_proj_covar={in_proj_covar}," - f" out_proj_covar={out_proj_covar}" - ) + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, 
in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - - def __init__(self, d_model: int): + def __init__(self, + d_model: int): super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + self.proj = ScaledLinear(d_model, d_model, + initial_scale=0.1, bias=False) - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): + def forward(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1463,7 +1430,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) + pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1477,19 +1444,24 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + """Feedforward module in Zipformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0, + min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) - def forward(self, x: Tensor): + def forward(self, + x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1509,7 +1481,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1539,10 +1513,7 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) self.depthwise_conv = nn.Conv1d( @@ -1556,10 +1527,8 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, + channels, channel_dim=1, + min_positive=0.05, max_positive=1.0, max_abs=20.0, ) @@ -1575,10 +1544,9 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. 
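PoolingModule above averages over time while ignoring padded frames: the key-padding mask is inverted, normalized per utterance to sum to one, and used as weights. The masked mean on its own (without the ScaledLinear projection) looks like this; the function name is ours:

    import torch

    def masked_time_mean(x: torch.Tensor, key_padding_mask: torch.Tensor) -> torch.Tensor:
        # x: (T, N, C);  key_padding_mask: (N, T), True at padded positions
        pooling_mask = key_padding_mask.logical_not().to(x.dtype)          # (N, T)
        pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True)
        pooling_mask = pooling_mask.transpose(0, 1).unsqueeze(-1)          # (T, N, 1)
        return (x * pooling_mask).sum(dim=0, keepdim=True)                 # (1, N, C)

    x = torch.randn(5, 2, 3)
    mask = torch.tensor([[False, False, False, True, True],
                         [False, False, False, False, False]])
    out = masked_time_mean(x, mask)
    print(out.shape)                                          # torch.Size([1, 2, 3])
    print(torch.allclose(out[0, 0], x[:3, 0].mean(dim=0)))    # True: padded frames are ignored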
@@ -1658,7 +1626,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, channel_dim=1), + ActivationBalancer(layer1_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1667,21 +1636,24 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, channel_dim=1), + ActivationBalancer(layer2_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, channel_dim=1), + ActivationBalancer(layer3_channels, + channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1702,7 +1674,6 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x - class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1746,12 +1717,15 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, + num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob + + def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1782,35 +1756,28 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint( - low=1, - high=int(num_inputs / self.random_prob), - size=(num_frames,), - device=scores.device, - ).unsqueeze(1) + mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), + size=(num_frames,), device=scores.device).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = ( - torch.arange(num_inputs, device=scores.device) - .unsqueeze(0) - .expand(num_frames, num_inputs) - ) + arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( + num_frames, num_inputs) mask = arange >= mask_start - apply_single_prob = torch.logical_and( - torch.rand(size=(num_frames, 1), device=scores.device) - < self.single_prob, - mask_start < num_inputs, - ) - single_prob_mask = torch.logical_and( - apply_single_prob, arange < mask_start - 1 - ) + apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), + device=scores.device) < self.single_prob, + mask_start < num_inputs) + single_prob_mask = torch.logical_and(apply_single_prob, + arange < mask_start - 1) - mask = torch.logical_or(mask, single_prob_mask) + mask = torch.logical_or(mask, + single_prob_mask) - scores = scores.masked_fill(mask, float("-inf")) + scores = scores.masked_fill(mask, float('-inf')) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1825,6 +1792,7 @@ class AttentionCombine(nn.Module): return ans + def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1833,8 +1801,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0, - ) + single_prob=0.0) + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1851,10 +1819,7 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), + num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) ) batch_size = 5 seq_len = 20 @@ -1872,18 +1837,19 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings - def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, + dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 822f8e44b..9d7335e77 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,24 +165,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. 
Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -277,7 +273,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -397,7 +394,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -456,7 +455,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -587,7 +589,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -620,7 +624,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -675,7 +680,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -712,12 +719,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -745,12 +753,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -779,7 +788,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " 
f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -807,7 +816,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 43eb0c1bc..49f469e29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -220,12 +217,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -254,12 +252,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -288,7 +287,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -327,7 +326,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index ed920dc03..e79a3a3aa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..497b89136 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,7 +160,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 716136812..373a48fc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 381a86a67..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,7 +92,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -130,10 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -148,11 +147,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -217,7 +214,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -287,45 +285,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." 
- ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -696,7 +691,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -745,7 +744,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -951,7 +952,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -972,7 +975,11 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -985,8 +992,12 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1000,7 +1011,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1012,8 +1026,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1041,7 +1054,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
" + f"Duration: {c.duration}" ) return False @@ -1138,7 +1152,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1156,7 +1172,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1191,7 +1207,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -1346,8 +1364,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 53f383c99..01be7090b 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/
-|-- lang_bpe
-|   |-- L.pt
-|   |-- Linv.pt
+streaming_models/  
+|-- lang_bpe  
+|   |-- L.pt  
+|   |-- Linv.pt  
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
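The README section above asks readers to re-check md5sum values after downloading the streaming models. A minimal Python sketch of that verification, assuming a hypothetical mapping of file paths to published digests (both are placeholders, not values from this repo):

```python
import hashlib
from pathlib import Path

# Hypothetical paths/digests; substitute the md5 values published with the models.
expected = {
    "streaming_models/lang_bpe/L.pt": "0123456789abcdef0123456789abcdef",
    "streaming_models/lang_bpe/bpe.model": "fedcba9876543210fedcba9876543210",
}

for name, digest in expected.items():
    path = Path(name)
    if not path.is_file():
        print(f"{name}: missing")
        continue
    md5 = hashlib.md5(path.read_bytes()).hexdigest()
    print(f"{name}: {'ok' if md5 == digest else 'MISMATCH'}")
```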
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index 4f7427c1f..ff4c91446 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,26 +309,36 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(0, num_frames - embed_left_context + 1, stride):
+                for cur in range(
+                    0, num_frames - embed_left_context + 1, stride
+                ):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
-                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(
+                        cur_feature, offset
+                    )
+                    cur_embed = cur_embed.permute(
+                        1, 0, 2
+                    )  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
-                            1, 1, -1
-                        )
+                        pos_emb_central = cur_pos_emb[
+                            0, (chunk_size - 1), :
+                        ].view(1, 1, -1)
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
+                        0
+                    )
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
+                        0
+                    )
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -403,7 +413,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -419,16 +431,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -462,7 +480,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -534,7 +554,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -714,7 +736,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -731,7 +755,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -757,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, offset: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -785,7 +813,9 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
+                    self.pe[0, self.pe.size(1) // 2].view(
+                        1, 1, self.pe.size(-1)
+                    ),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1020,9 +1050,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1090,25 +1120,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1147,16 +1185,24 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd, offset=offset
+        )  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1190,9 +1236,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
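The attention hunks above compute the Transformer-XL style relative-position term (matrix_bd) and reshape it with rel_shift. Conceptually, that shift selects, for each query position i and key position j, the table entry stored for offset i - j out of 2*T - 1 candidates; the toy gather below illustrates that mapping (it is not the optimized shift used in the code, and the table ordering in the recipe may differ):

```python
import torch

T = 5
# Stand-in for the 2*T - 1 relative-position entries, ordered from offset -(T-1) to +(T-1).
rel_table = torch.arange(-(T - 1), T, dtype=torch.float32)
# offsets[i, j] == i - j
offsets = torch.arange(T).unsqueeze(1) - torch.arange(T).unsqueeze(0)
# Shift offsets into [0, 2T-2] and gather one entry per (query, key) pair.
rel_scores = rel_table[offsets + (T - 1)]
print(rel_scores)  # row i, column j holds the entry for offset i - j
```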
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index 5a8149aad..a74c51836 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,7 +28,6 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
-
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -63,36 +62,32 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help=(
-            "Frames of right context"
-            "-1 for whole right context, i.e. non-streaming decoding"
-        ),
+        help="Frames of right context"
+        "-1 for whole right context, i.e. non-streaming decoding",
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,only used during decoding",
+        help="tailing dummy frames padded to the right,"
+        "only used during decoding",
     )
 
     parser.add_argument(
@@ -144,7 +139,8 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;"
+        "e.g. 60,62,63,72",
     )
 
     return parser
@@ -252,9 +248,13 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(
+        memory_key_padding_mask, 0
+    )  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
+    token_ids = [
+        remove_duplicates_and_blank(token_id) for token_id in token_ids
+    ]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,7 +337,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return results
 
@@ -355,18 +357,15 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = (
-            f"epoch-avg-{avg_models}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     else:
-        result_file_prefix = (
-            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
-            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
-        )
+        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
+        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -375,7 +374,8 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir
+            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,7 +384,9 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -472,7 +474,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -503,7 +507,9 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
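The decoding hunks above take a frame-wise argmax, zero out padded frames, and run remove_duplicates_and_blank before BPE decoding. A small re-implementation of that collapse step, assuming blank id 0 (the actual helper lives in icefall and is only mirrored here for illustration):

```python
def ctc_greedy_collapse(ids, blank_id=0):
    """Collapse repeated ids and drop blanks, as in CTC greedy search."""
    out, prev = [], None
    for i in ids:
        if i != blank_id and i != prev:
            out.append(i)
        prev = i
    return out


assert ctc_greedy_collapse([0, 3, 3, 0, 3, 5, 5, 0]) == [3, 3, 5]
```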
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index 553b7d092..e41b7ea78 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,7 +405,9 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -434,7 +436,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -547,7 +551,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -662,7 +668,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
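compute_loss above records utt_pad_proportion by summing, over the batch, each utterance's padding fraction (the comment in the hunk calls this the averaged padding proportion, presumably normalized by the metrics tracking elsewhere). A tiny numeric illustration with made-up lengths:

```python
import torch

padded_T = 10                           # feature.size(1) after padding
num_frames = torch.tensor([10, 7, 4])   # true frame counts per utterance
pad_sum = ((padded_T - num_frames) / padded_T).sum().item()
print(pad_sum)  # (0 + 3 + 6) / 10 = 0.9
```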
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index 0c87fdf1b..bc78e4a41 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,7 +149,9 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
+            self.decoder_output_layer = torch.nn.Linear(
+                d_model, self.decoder_num_class
+            )
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -284,17 +286,23 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -355,17 +363,23 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+        ys_in_pad = pad_sequence(
+            ys_in, batch_first=True, padding_value=float(eos_id)
+        )
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+        ys_out_pad = pad_sequence(
+            ys_out, batch_first=True, padding_value=float(-1)
+        )
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+            device
+        )
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -638,7 +652,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
@@ -840,7 +856,9 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    lengths = [
+        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+    ]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -861,7 +879,9 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+def decoder_padding_mask(
+    ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
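The decoder_forward hunks above prepend SOS, append EOS, and pad the two target streams with different fill values (eos_id for the decoder input, -1 as an ignore value for the target). A self-contained sketch of that padding pattern; add_sos/add_eos are icefall helpers and are inlined here for illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

sos_id = eos_id = 1
token_ids = [[3, 5, 7], [2, 9]]

ys_in = [torch.tensor([sos_id] + y) for y in token_ids]
ys_out = [torch.tensor(y + [eos_id]) for y in token_ids]

ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))

print(ys_in_pad)   # [[1, 3, 5, 7], [1, 2, 9, 1]]
print(ys_out_pad)  # [[3, 5, 7, 1], [2, 9, 1, -1]]
```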
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 63afd6be2..355ccc99a 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,18 +77,17 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. "
+            "Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -100,74 +99,59 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--drop-last",
@@ -179,18 +163,17 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -204,22 +187,18 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -245,16 +224,20 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -269,7 +252,9 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -313,7 +298,9 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -369,7 +356,9 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 94ba0a4dc..7d0cd0bf3 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -466,7 +467,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -495,7 +498,9 @@ def main():
             G=G,
         )
 
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 1731e1ebe..5e04c11b4 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,7 +66,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
+            [
+                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
+                for _ in range(5)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 722e8f003..2baeb6bba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 071ac792b..6b37d5c23 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,7 +355,9 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
+        .sum()
+        .item()
     )
 
     return loss, info
@@ -468,7 +470,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index b45b6a9d8..11032f31a 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -121,7 +123,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -153,9 +157,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index f30332cea..5f233df87 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -231,7 +228,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -246,7 +245,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -317,7 +318,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -350,7 +353,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 4d9f937f5..5a5db30c4 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -241,7 +238,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 7aadfbcd1..1db2df648 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,11 +60,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -89,12 +87,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -192,9 +188,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -252,7 +249,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -288,7 +287,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index fe8732301..2a165b0c1 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,8 +117,12 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(4 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -344,7 +348,9 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
+        first_layer = LayerNormLSTMLayer(
+            input_size=input_size, **factory_kwargs
+        )
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -379,7 +385,9 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
+        output_states = torch.jit.annotate(
+            List[Tuple[torch.Tensor, torch.Tensor]], []
+        )
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -448,8 +456,12 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
-            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_ih = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 74c94cc70..8591e2d8a 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,7 +254,9 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -301,7 +303,9 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -590,7 +594,9 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
+    (
+        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
+    ).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -712,7 +718,9 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
+    states = [
+        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
+    ]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 674ea10a6..1dd65eddb 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,7 +396,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -518,7 +520,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -655,7 +659,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 5342c3e8c..3531a9633 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -122,7 +124,9 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
+    cache: Dict[
+        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+    ] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -154,9 +158,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
-                    1, 1
-                )
+                decoder_input = torch.tensor(
+                    [y_star.ys[-1]], device=device
+                ).reshape(1, 1)
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 61b9de504..604235e2a 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,19 +71,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -228,7 +225,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -243,7 +242,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -314,7 +315,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -347,7 +350,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 038d80077..3dc992dd2 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,7 +48,9 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
+            self.encoder_embed = Conv2dSubsampling(
+                num_features, real_hidden_size
+            )
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 57bda63fd..cdb801e79 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,7 +400,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -522,7 +524,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -661,7 +665,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index 65f2c58d8..f143611ea 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,7 +193,9 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 1d79eef9d..ea985f30d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,7 +478,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -494,7 +496,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -782,7 +786,9 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -881,7 +887,9 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -951,9 +959,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 89992856d..48769e9d1 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,19 +54,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -127,7 +124,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -164,7 +162,9 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
         batch_size = encoder_out.size(0)
 
@@ -204,7 +204,9 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index d279eae85..cde52c9fc 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,7 +209,10 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -418,7 +421,9 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -434,16 +439,22 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
-        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -475,7 +486,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -501,7 +514,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -566,7 +581,9 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -608,7 +625,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
+        src, conv_cache = self.conv_module(
+            src, states[1], right_context=right_context
+        )
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -760,7 +779,9 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -777,7 +798,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -803,7 +826,9 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
+    def forward(
+        self, x: torch.Tensor, left_context: int = 0
+    ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1067,9 +1092,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1138,25 +1163,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1195,10 +1228,14 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1206,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1251,7 +1290,9 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1263,9 +1304,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1373,12 +1418,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 314f49154..74bba9cad 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index a182d91e2..fbc2373a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,7 +87,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 7c10b4348..8bd0bdea1 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,20 +68,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -112,7 +109,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -246,7 +244,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index e1625992d..93cccbd8c 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,9 +60,13 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
+        encoder_out_list = [
+            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
+        ]
 
-        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
+        decoder_out_list = [
+            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
+        ]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index bd7eeff28..b64521801 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index 9af46846a..b00fc34f1 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,13 +140,16 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second)
+                for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
+            word_starting_time_dict[cuts[i].id] = list(
+                zip(words, word_starting_time)
+            )
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -157,7 +160,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index 65b08d425..d1350c8ab 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,7 +29,9 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    c = Conformer(
+        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index bcb883fa5..ae93f3348 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -421,7 +422,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -542,7 +545,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +664,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -689,7 +698,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index b3ff153c1..e851dcc32 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,7 +250,9 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index 86ef9e5b6..ac2807241 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,19 +94,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -174,7 +171,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -232,7 +230,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -440,7 +450,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index d95eeb1f4..57c1a6094 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,20 +63,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -107,7 +104,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -178,7 +176,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 793931e3b..292f77f03 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index 68e247f23..ea15c9040 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,7 +136,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -409,7 +410,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -530,7 +533,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +652,13 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(f"Before removing short and long utterances: {num_in_total}")
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -677,7 +686,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 22b6ab911..d596e05cb 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,19 +95,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
 
@@ -249,7 +249,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -295,7 +298,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -368,7 +375,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +410,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -441,7 +451,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index fad9a6977..b6b69d932 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index efd257b5d..f297fa2b2 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,11 +90,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -119,12 +117,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -171,7 +167,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -261,7 +259,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,7 +334,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index 1e1188ca6..ef51a7811 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,7 +41,9 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 88987d91c..27912738c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,7 +114,8 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -169,7 +170,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -467,7 +469,9 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -631,7 +635,9 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -778,7 +784,9 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+        )
     else:
         cuts_musan = None
 
@@ -817,7 +825,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index bed3856e4..af54dbd07 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,7 +135,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 3790045fa..877720e7b 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,7 +54,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 9bea28a41..6cb8b65ae 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,7 +87,9 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(part["recordings"] for part in manifests.values())
+            recordings=combine(
+                part["recordings"] for part in manifests.values()
+            )
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -106,6 +108,8 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 20ff6d7ab..8116e7605 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,7 +103,11 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
+        stop = (
+            min(args.stop, args.num_splits)
+            if args.stop > 0
+            else args.num_splits
+        )
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -125,7 +129,9 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
+            cut_set = load_manifest_lazy(
+                src_dir / f"cuts_{partition}_raw.jsonl.gz"
+            )
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -138,7 +144,9 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 508d4acd8..8c8f1c133 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,7 +55,9 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
+        train_cuts
+        + train_cuts.perturb_speed(0.9)
+        + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -71,7 +73,9 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 83f95d123..f165f6e60 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,12 +70,10 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -87,81 +85,67 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it "
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it "
+            "with training dataset. ",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the BucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the BucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -173,12 +157,10 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
     def train_dataloaders(
@@ -194,20 +176,24 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "cuts_musan.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -222,7 +208,9 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -239,7 +227,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
             )
         else:
@@ -292,7 +282,9 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -336,7 +328,9 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 72a7cd1c1..c39bd0530 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,7 +76,11 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -113,11 +117,9 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
@@ -185,7 +187,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -243,7 +246,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -258,7 +263,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -304,7 +312,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -377,7 +389,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -410,7 +424,9 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+            results_char.append(
+                (res[0], list("".join(res[1])), list("".join(res[2])))
+            )
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -422,23 +438,32 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
-    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    test_set_wers = {
+        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
+    }
+    test_set_cers = {
+        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
+    }
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+                "{}\t{}\t{}".format(
+                    key, test_set_wers[key], test_set_cers[key]
+                ),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+        s += "{}\t{}\t{}{}\n".format(
+            key, test_set_wers[key], test_set_cers[key], note
+        )
         note = ""
     logging.info(s)
 
@@ -471,7 +496,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -503,7 +530,8 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 1f18ae2f3..77faa3c0e 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,7 +50,11 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import str2bool
 
 
@@ -63,20 +67,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -118,7 +119,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -194,7 +196,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index cd835a7b4..dda29b3e5 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -153,7 +155,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need to be "
+        "changed.",
     )
 
     parser.add_argument(
@@ -176,45 +179,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -554,16 +554,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -726,7 +733,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 602e50d29..4582609ac 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,7 +84,9 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -110,7 +112,9 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 1262baf63..2c5b8b8b3 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,7 +87,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index c8cf9b881..e5ae89ec4 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,7 +317,9 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+    parser.add_argument(
+        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+    )
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index 74e025ad7..d4cf62bba 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,7 +88,9 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
+    fsa_disambig = lexicon_to_fst(
+        lexicon_disambig, phone2id=phone2id, word2id=word2id
+    )
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 2be639b7a..71be2a613 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 02bd6e2cc..49bfb148b 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,12 +74,10 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
 
         group.add_argument(
@@ -93,81 +91,66 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
 
         group.add_argument(
@@ -181,18 +164,17 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -206,22 +188,18 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -244,20 +222,24 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -272,7 +254,9 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -316,7 +300,9 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -374,7 +360,9 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b2aef7e86..b624913f5 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,24 +124,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -212,7 +208,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -271,7 +268,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -304,7 +303,10 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -373,7 +375,9 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+                hyp = sp.decode(
+                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
+                )
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -392,11 +396,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
-                hyps,
-                zh_hyps,
-                en_hyps,
-            )
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): (hyps, zh_hyps, en_hyps)
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -502,7 +506,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results, zh_results, en_results
 
 
@@ -535,7 +541,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -578,7 +585,9 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -610,12 +619,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -638,12 +648,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -671,7 +682,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 94a4c7a2e..8f900208a 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,24 +92,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -143,7 +139,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -179,12 +176,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -207,12 +205,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -240,7 +239,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -278,7 +277,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 198242129..dbe213b24 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,11 +84,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -117,12 +115,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -200,9 +197,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -265,11 +263,15 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -365,7 +367,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 676e8c904..ca35eba45 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,7 +86,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -212,7 +214,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -235,45 +238,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -600,7 +600,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -630,15 +634,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -817,7 +828,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -931,7 +944,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 733ebf235..327962a79 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,7 +83,9 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -102,7 +104,9 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 9dbcc9d9e..49544ccb3 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,7 +25,9 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
+    parser.add_argument(
+        "--texts", type=List[str], help="The input transcripts list."
+    )
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index b9160b6d4..35dd332e8 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -87,7 +88,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 7ea4e89a4..1039ac5bb 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
+import lhotse
 import argparse
 import logging
 from pathlib import Path
 
-import lhotse
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -62,7 +61,9 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
+    sups = lhotse.load_manifest(
+        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
+    )
     for s in sups:
         texts.append(s.text)
 
@@ -82,7 +83,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 6bae33e65..2b294e601 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,20 +94,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -175,7 +172,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -233,7 +231,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,7 +248,10 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -294,7 +297,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -367,7 +374,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +409,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index 244740932..a1c3bcea3 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,20 +65,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -109,7 +106,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -181,7 +179,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 00545f107..8480ac029 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,11 +93,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -124,12 +122,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -169,7 +165,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -206,9 +203,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -273,7 +271,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,7 +298,10 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -350,7 +353,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 70c5e290f..8d5cdf683 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,45 +133,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -559,7 +556,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -679,7 +678,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index f90f79d8c..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -18,6 +18,7 @@
 
 import argparse
 import logging
+
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -62,12 +63,10 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -79,90 +78,74 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -174,25 +157,23 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset.",
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
         """
         Args:
@@ -205,7 +186,9 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
 
             input_transforms.append(
                 SpecAugment(
@@ -225,16 +208,20 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
+            )
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -260,7 +247,9 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -317,7 +306,9 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -348,7 +339,9 @@ class TedLiumAsrDataModule:
         logging.debug("About to create test dataset")
         if self.args.on_the_fly_feats:
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -382,9 +375,13 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
+        )
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 1f99edaf3..77caf6460 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
-                1, context_size
-            )
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,7 +148,9 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
 
@@ -164,7 +166,9 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -340,9 +344,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -379,7 +383,9 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
 
         logits = model.joiner(
             current_encoder_out,
@@ -448,9 +454,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index 12d0e2652..d3e9e55e7 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,19 +81,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -133,7 +130,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -252,7 +250,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,7 +275,9 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -346,7 +348,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -379,7 +383,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f9a3814c6..f0c6f32b6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,7 +90,9 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index 0b2ae970b..c32b1d002 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,20 +69,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -113,7 +110,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -249,7 +247,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 912d65497..c0e3bb844 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,11 +82,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -112,12 +110,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -131,7 +127,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -225,9 +222,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -287,7 +285,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 6fed32e81..09cbf4a00 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,7 +133,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -524,7 +525,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -644,7 +647,9 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar(
+                "train/learning_rate", cur_lr, params.batch_idx_train
+            )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index d8ceb82b6..b78c16b88 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
+```
\ No newline at end of file
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 32c248d7e..58cab4cf2 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,7 +146,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index ecdf10ba9..f25786a0c 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,7 +85,9 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -99,7 +101,9 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 0cf0f0deb..04023a9ab 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,7 +62,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    supervisions_train = (
+        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    )
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -95,7 +97,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index d11cd3a05..ae1b96a68 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it
+#	     on 39 phones. About how to get these LM files, you can know it 
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#
+#	
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 5a59a13ce..4f2aa2340 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,19 +57,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -339,7 +336,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -401,7 +400,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,7 +462,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -482,7 +485,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 9a594a969..4d2199ace 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
-from typing import Optional
-
 import torch
 import torch.nn as nn
+
 from torch import Tensor
+from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,7 +261,9 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
+                hx = hx.reshape(
+                    self.num_layers, self.batch_size * 2, self.hidden_size
+                )
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -443,7 +445,9 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
+                    torch.ones(
+                        self.N_drop_masks, self.hidden_size, device=w.device
+                    )
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index da669bc39..7da285944 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 48b7feda0..452c2a7cb 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d957c22e1..1554e987f 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,12 +63,10 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -80,91 +78,75 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -172,13 +154,15 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.feature_dir / "musan_cuts.jsonl.gz"
+        )
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -194,9 +178,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
-            "num_frame_masks"
-        ]
+        num_frame_masks_parameter = inspect.signature(
+            SpecAugment.__init__
+        ).parameters["num_frame_masks"]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -228,7 +212,9 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -277,7 +263,9 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -311,14 +299,20 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                )
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
+            sampler = SingleCutSampler(
+                cuts_test, max_duration=self.args.max_duration
+            )
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+            test_dl = DataLoader(
+                test, batch_size=None, sampler=sampler, num_workers=1
+            )
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 319ee5515..5e7300cf2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,19 +56,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--method",
@@ -338,7 +335,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -400,7 +399,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -460,7 +461,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -480,7 +483,9 @@ def main():
         G=G,
     )
 
-    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(
+        params=params, test_set_name=test_set, results_dict=results_dict
+    )
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index e211ad80d..51edb97e2 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,7 +74,10 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
+            [
+                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
+                for _ in range(4)
+            ]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 0c72c973b..5f478da1c 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,7 +29,11 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_whole_lattice,
+)
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -42,11 +46,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -56,7 +58,9 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "--method",
@@ -99,12 +103,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -142,9 +144,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -212,7 +215,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -264,7 +269,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index be1ecffaa..849256b98 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,7 +449,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index bd73e520e..8a9f6ed30 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,7 +20,12 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import (
+    CutSet,
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    LilcomHdf5Writer,
+)
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -78,7 +83,9 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index c228597b8..a882b6113 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,10 +62,8 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help=(
-            "The maximum number of audio seconds in a batch."
-            "Determines batch size dynamically."
-        ),
+        help="The maximum number of audio seconds in a batch."
+        "Determines batch size dynamically.",
     )
 
     parser.add_argument(
@@ -154,7 +152,9 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index d8622842f..8bc073c75 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,7 +83,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -136,7 +138,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 93ce750f8..817969c47 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,7 +115,11 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            cut_set = (
+                cut_set
+                + cut_set.perturb_speed(0.9)
+                + cut_set.perturb_speed(1.1)
+            )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index e121d842c..1c463cf1c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help=(
-            "number of characters to split, i.e.,                         aabb -> a a b"
-            " b with -n 1 and aa bb with -n 2"
-        ),
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument("--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol"
+    )
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,7 +66,9 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text"
+    )
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -106,7 +108,8 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id for txt in text
+                    token_table[txt] if txt in token_table else oov_id
+                    for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -132,7 +135,9 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer
+        )
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index da7d7e061..755fbb2d7 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq
+      echo "This script is intended to be used with jq but you have not installed jq 
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bd92ac115..10c953e3b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,12 +81,10 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -98,91 +96,75 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
         group.add_argument(
@@ -196,22 +178,18 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help=(
-                "Used only when --enable-spec-aug is True. "
-                "It specifies the factor for time warping in SpecAugment. "
-                "Larger values mean more warping. "
-                "A value less than 1 means to disable time warp."
-            ),
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, select noise from MUSAN and mix it"
-                "with training dataset. "
-            ),
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
         )
 
         group.add_argument(
@@ -234,20 +212,24 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+        cuts_musan = load_manifest(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
+        )
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(
+                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+                )
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -262,7 +244,9 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -305,7 +289,9 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -362,7 +348,9 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -426,7 +414,8 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir
+            / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -438,9 +427,13 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
+        )
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
+        )
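
The help-text hunks above rely on Python's implicit concatenation of adjacent string literals, so every fragment has to carry its own trailing space or the words run together. A minimal sketch of the convention, with an illustrative message that is not part of the patch:

    # Adjacent string literals are joined verbatim; the trailing space on the
    # first fragment is what keeps "it" and "with" apart.
    help_text = (
        "When enabled, select noise from MUSAN and mix it "
        "with the training dataset."
    )
    assert "mix it with" in help_text
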
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 6e856248c..f0c9bebec 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,7 +114,11 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -133,30 +137,25 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help=(
-            "It specifies the batch checkpoint to use for decoding."
-            "Note: Epoch counts from 0."
-        ),
+        help="It specifies the batch checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -253,7 +252,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,7 +328,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -387,7 +389,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -433,7 +438,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -506,7 +515,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -539,7 +550,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -651,7 +663,9 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
@@ -702,7 +716,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -712,7 +727,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -723,7 +739,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
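
For reference, the shard-listing pattern restored in the hunks above can be sketched on its own. The "shared-*.tar" pattern is taken from the diff; the helper name and directory are illustrative:

    import glob
    import os

    def list_shards(shard_dir: str) -> list:
        # Collect the webdataset shards in a stable, sorted order so that
        # every worker sees the same shard sequence.
        return [
            str(path)
            for path in sorted(glob.glob(os.path.join(shard_dir, "shared-*.tar")))
        ]
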
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index c742593df..933642a0f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,20 +126,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -208,7 +205,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     return parser
@@ -470,9 +468,13 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    encoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_encoder_proj.onnx"
+    )
 
-    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+    decoder_proj_filename = str(joiner_filename).replace(
+        ".onnx", "_decoder_proj.onnx"
+    )
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -643,7 +645,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
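
The joiner export above derives the projection filenames from the joiner filename with a simple suffix replacement. A minimal sketch with an illustrative base name:

    joiner_filename = "joiner.onnx"
    encoder_proj_filename = joiner_filename.replace(".onnx", "_encoder_proj.onnx")
    decoder_proj_filename = joiner_filename.replace(".onnx", "_decoder_proj.onnx")
    assert encoder_proj_filename == "joiner_encoder_proj.onnx"
    assert decoder_proj_filename == "joiner_decoder_proj.onnx"
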
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index ed9020c67..e5cc47bfe 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -147,9 +145,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -332,7 +331,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
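
The sound-file loading reverted above boils down to the following sketch, assuming torchaudio is available; only the first channel is kept and the sample rate is checked against the expected 16 kHz, as in the diff:

    from typing import List

    import torch
    import torchaudio

    def read_sound_files(
        filenames: List[str], expected_sample_rate: int = 16000
    ) -> List[torch.Tensor]:
        ans = []
        for f in filenames:
            wave, sample_rate = torchaudio.load(f)
            assert sample_rate == expected_sample_rate, (
                f"expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
            ans.append(wave[0])  # use only the first channel
        return ans
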
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index a46ff5a07..c396c50ef 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,7 +219,9 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
+        joiner_encoder_proj_inputs = {
+            encoder_proj_input_name: encoder_out.numpy()
+        }
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -228,10 +230,16 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
+        ), (
+            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
+            .abs()
+            .max()
+        )
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
+        joiner_decoder_proj_inputs = {
+            decoder_proj_input_name: decoder_out.numpy()
+        }
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -240,7 +248,11 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
+        ), (
+            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
+            .abs()
+            .max()
+        )
 
 
 @torch.no_grad()
@@ -292,7 +304,9 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
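
The ONNX-vs-PyTorch checks above pair torch.allclose with the largest absolute difference as the assertion message, so a failing run reports how far apart the two outputs are. A self-contained sketch with illustrative tensors:

    import torch

    onnx_out = torch.zeros(2, 3)
    torch_out = torch.zeros(2, 3)
    assert torch.allclose(onnx_out, torch_out, atol=1e-5), (
        (onnx_out - torch_out).abs().max()
    )
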
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index f7d962008..3770fbbb4 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,12 +111,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -151,9 +149,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -201,7 +200,11 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
+        {
+            joiner_encoder_proj.get_inputs()[
+                0
+            ].name: packed_encoder_out.data.numpy()
+        },
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -386,7 +389,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 26c9c2b8c..9a549efd9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -162,7 +158,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
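
The feature padding restored above fills the padded frames with math.log(1e-10), i.e. a very small log-energy, so padding behaves like silence in the log-mel domain. A minimal sketch with illustrative shapes:

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    features = [torch.randn(30, 80), torch.randn(25, 80)]
    feature_lengths = torch.tensor([f.size(0) for f in features])
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )
    print(features.shape)  # torch.Size([2, 30, 80])
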
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index e020c4c05..d3cc7c9c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,7 +115,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -217,45 +219,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -591,15 +590,22 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -756,7 +762,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -856,7 +864,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
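
The loss combination reverted above scales the pruned loss by a warmup-dependent factor so that it does not overwhelm the simple loss before the alignment is learned. A minimal sketch of that schedule with illustrative values:

    import torch

    def combine_losses(simple_loss, pruned_loss, warmup, simple_loss_scale=0.5):
        # 0.0 before warmup, 0.1 while warming up, 1.0 once fully warmed up.
        pruned_loss_scale = (
            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
        )
        return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

    print(combine_losses(torch.tensor(2.0), torch.tensor(3.0), warmup=0.5))
    # tensor(1.)
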
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index 1023c931a..dd27c17f0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,7 +210,10 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
+        if (
+            len(self._init_state) == 2
+            and self._init_state[0].size(1) == left_context
+        ):
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -430,7 +433,9 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.self_attn = RelPositionMultiheadAttention(
+            d_model, nhead, dropout=0.0
+        )
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -448,7 +453,9 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )
 
         self.norm_final = BasicNorm(d_model)
 
@@ -513,7 +520,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        conv, _ = self.conv_module(
+            src, src_key_padding_mask=src_key_padding_mask
+        )
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -757,7 +766,9 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -773,7 +784,9 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1060,9 +1073,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
-                3, dim=-1
-            )
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1131,25 +1144,33 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 2D attn_mask is not correct."
+                    )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+                    raise RuntimeError(
+                        "The size of the 3D attn_mask is not correct."
+                    )
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                    "attn_mask's dimension {} is not supported".format(
+                        attn_mask.dim()
+                    )
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        if (
+            key_padding_mask is not None
+            and key_padding_mask.dtype == torch.uint8
+        ):
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
-                " instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1187,15 +1208,23 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(
+            q_with_bias_u, k
+        )  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p
+        )  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = (
+            matrix_ac + matrix_bd
+        )  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, tgt_len, -1
+        )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,17 +1265,21 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
-            else:
-                # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(
                     1
                 ).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(
+                    0
+                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1258,9 +1291,13 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(
+            attn_output, out_proj_weight, out_proj_bias
         )
-        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1393,12 +1430,16 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert not self.training, "Cache should be None in training time"
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (-right_context),  # noqa
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
                         ...,
                     ]
                 else:
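
The attention code above broadcasts the key-padding mask against the attention mask before zeroing the masked weights. A minimal sketch of the broadcasting, with illustrative shapes:

    import torch

    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
    attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.bool)
    key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    key_padding_mask[:, -1] = True  # the last frame of every utterance is padding

    # (1, 1, T, S) | (B, 1, 1, S) -> (B, 1, T, S), as in the reverted code.
    combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
    weights = torch.rand(bsz, num_heads, tgt_len, src_len)
    weights = weights.masked_fill(combined_mask, 0.0)
    print(weights[..., -1].abs().sum())  # tensor(0.)
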
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 3d66f9dc9..344e31283 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,24 +160,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -248,7 +244,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -345,7 +342,9 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
     hyps = []
 
@@ -361,7 +360,10 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -407,7 +409,11 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -478,7 +484,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -511,7 +519,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -580,12 +589,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -608,12 +618,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -641,7 +652,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -709,7 +720,8 @@ def main():
         )
 
     dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -719,7 +731,8 @@ def main():
     )
 
     test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -730,7 +743,9 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
+        for path in sorted(
+            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
+        )
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
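
The fast_beam_search result key assembled in the hunks above is just the beam-search parameters concatenated into one string. A minimal sketch with illustrative values:

    beam, max_contexts, max_states = 4, 4, 8
    key = (
        f"beam_{beam}_"
        f"max_contexts_{max_contexts}_"
        f"max_states_{max_states}"
    )
    print(key)  # beam_4_max_contexts_4_max_states_8
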
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index e522943c0..386248554 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,7 +75,9 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -89,11 +91,13 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
-                decoding_graph
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
+                k2.RnntDecodingStream(decoding_graph)
             )
         else:
-            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+            raise ValueError(
+                f"Unsupported decoding method: {params.decoding_method}"
+            )
 
     @property
     def done(self) -> bool:
@@ -122,10 +126,13 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
+        ret_length = min(
+            self.num_frames - self.num_processed_frames, chunk_length
+        )
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
+            self.num_processed_frames : self.num_processed_frames  # noqa
+            + ret_length
         ]
 
         self.num_processed_frames += chunk_size
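
The streaming feature consumption reverted above returns chunk_size frames plus pad_length look-ahead frames while only advancing the read pointer by chunk_size. A minimal sketch with illustrative values (right_context=2, subsampling_factor=4):

    import torch

    right_context, subsampling_factor = 2, 4
    pad_length = (right_context + 2) * subsampling_factor + 3  # 19

    features = torch.randn(100, 80)
    num_processed_frames = 0
    chunk_size = 16

    chunk_length = chunk_size + pad_length
    ret_length = min(features.size(0) - num_processed_frames, chunk_length)
    ret_features = features[num_processed_frames : num_processed_frames + ret_length]
    num_processed_frames += chunk_size
    print(ret_features.shape)  # torch.Size([35, 80])
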
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index fb53f70ab..d0a7fd69f 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,20 +90,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -134,7 +131,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -203,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 9834189d8..1b064c874 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,11 +80,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -109,12 +107,10 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     parser.add_argument(
@@ -161,7 +157,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -192,9 +189,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -255,7 +253,9 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,7 +280,10 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -332,7 +335,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 810d94135..651aff6c9 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,10 +173,14 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
+                num_active_paths
+            )
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 31a7fe605..ff96c6487 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,24 +119,20 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch' and '--iter'"
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help=(
-            "Whether to load averaged model. Currently it only supports "
-            "using --epoch. If True, it would decode with the averaged model "
-            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-            "Actually only the models with epoch number of `epoch-avg` and "
-            "`epoch` are loaded for averaging. "
-        ),
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
     )
 
     parser.add_argument(
@@ -205,7 +201,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -314,7 +311,9 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+        greedy_search(
+            model=model, encoder_out=encoder_out, streams=decode_streams
+        )
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -334,7 +333,9 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -388,7 +389,9 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -458,7 +461,9 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+        raise ValueError(
+            f"Unsupported decoding method: {params.decoding_method}"
+        )
 
     return {key: decode_results}
 
@@ -494,7 +499,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -559,12 +565,13 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -587,12 +594,13 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -620,7 +628,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                "Calculating the averaged model over epoch range from "
+                f"Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 40c9665f7..2052e9da7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,7 +98,9 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -258,7 +260,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be changed.",
+        help="The initial learning rate.  This value should not need "
+        "to be changed.",
     )
 
     parser.add_argument(
@@ -281,45 +284,42 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help=(
-            "The prune range for rnnt loss, it means how many symbols(context)"
-            "we are using to compute the loss"
-        ),
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help=(
-            "The scale to smooth the loss with lm (output of prediction network) part."
-        ),
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help=(
-            "To get pruning ranges, we will calculate a simple version"
-            "loss(joiner is just addition), this simple loss also uses for"
-            "training (as a regularization item). We will scale the simple loss"
-            "with this parameter before adding to the final loss."
-        ),
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
     )
 
     parser.add_argument(
@@ -665,7 +665,11 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -697,16 +701,23 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
         )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -830,7 +841,9 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -888,7 +901,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1001,7 +1016,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1169,7 +1184,9 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index 7234ca929..f83be05cf 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,7 +128,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 75d95df68..9a4e8a36f 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,7 +54,9 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
+    extractor = Fbank(
+        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
+    )
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -69,7 +71,9 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                    cut_set
+                    + cut_set.perturb_speed(0.9)
+                    + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -83,7 +87,9 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 21860d2f5..85e5f1358 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,12 +56,10 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description=(
-                "These options are used for the preparation of "
-                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-                "effective batch sizes, sampling strategies, applied data "
-                "augmentations, etc."
-            ),
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
         )
         group.add_argument(
             "--feature-dir",
@@ -73,91 +71,75 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help=(
-                "Maximum pooled recordings duration (seconds) in a "
-                "single batch. You can reduce it if it causes CUDA OOM."
-            ),
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, the batches will come from buckets of "
-                "similar duration (saves padding frames)."
-            ),
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help=(
-                "The number of buckets for the DynamicBucketingSampler"
-                "(you might want to increase it for larger datasets)."
-            ),
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, utterances (cuts) will be concatenated "
-                "to minimize the amount of padding."
-            ),
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help=(
-                "Determines the maximum duration of a concatenated cut "
-                "relative to the duration of the longest cut in a batch."
-            ),
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help=(
-                "The amount of padding (in seconds) inserted between "
-                "concatenated cuts. This padding is filled with noise when "
-                "noise augmentation is used."
-            ),
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help=(
-                "When enabled, use on-the-fly cut mixing and feature "
-                "extraction. Will drop existing precomputed feature manifests "
-                "if available."
-            ),
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled (=default), the examples will be shuffled for each epoch."
-            ),
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help=(
-                "When enabled, each batch will have the "
-                "field: batch['supervisions']['cut'] with the cuts that "
-                "were used to construct it."
-            ),
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that collect the batches.",
+            help="The number of training dataloader workers that "
+            "collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -168,7 +150,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                "Using cut concatenation with duration factor "
+                f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 41afe0404..9d4ab4b61 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,19 +35,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -204,7 +201,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -275,7 +274,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -296,7 +297,9 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
+        torch.save(
+            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
+        )
         return
 
     model.to(device)
@@ -314,7 +317,9 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 09a8672ae..14220be19 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,11 +41,9 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help=(
-            "Path to the checkpoint. "
-            "The checkpoint is assumed to be saved by "
-            "icefall.checkpoint.save_checkpoint()."
-        ),
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
     )
 
     parser.add_argument(
@@ -55,18 +53,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
+    parser.add_argument(
+        "--HLG", type=str, required=True, help="Path to HLG.pt."
+    )
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help=(
-            "The input sound file(s) to transcribe. "
-            "Supported formats are those supported by torchaudio.load(). "
-            "For example, wav and flac are supported. "
-            "The sample rate has to be 16kHz."
-        ),
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
     )
 
     return parser
@@ -103,9 +101,10 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert (
-            sample_rate == expected_sample_rate
-        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -160,7 +159,9 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -200,7 +201,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 335493491..f32a27f35 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index de478334e..6714180db 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,19 +48,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
     parser.add_argument(
         "--exp-dir",
@@ -119,7 +116,9 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -187,7 +186,9 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -302,7 +303,9 @@ def main():
         model=model,
     )
 
-    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index 88866ae81..deb92107d 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,7 +430,9 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index c31db6e4c..235160e14 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,7 +71,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -94,7 +96,9 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 8aa0a8eeb..5069b78e8 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,11 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
+    iter_checkpoints = [
+        (int(pattern.search(c).group(1)), c) for c in checkpoints
+    ]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
+    iter_checkpoints = sorted(
+        iter_checkpoints, reverse=True, key=lambda x: x[0]
+    )
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -465,5 +469,7 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            v += (
+                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+            )
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index e4c614c4e..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,9 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -366,7 +370,9 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+        one_best = k2.shortest_path(
+            path_lattice, use_double_scores=use_double_scores
+        )
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -436,7 +442,9 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
+        ragged_scores = k2.RaggedTensor(
+            scores_shape, self.fsa.scores.contiguous()
+        )
 
         tot_scores = ragged_scores.sum()
 
@@ -475,7 +483,9 @@ def one_best_decoding(
             am_scores = saved_am_scores / lm_scale
             lattice.scores = am_scores + lattice.lm_scores
 
-            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            best_path = k2.shortest_path(
+                lattice, use_double_scores=use_double_scores
+            )
             key = f"lm_scale_{lm_scale}"
             ans[key] = best_path
         return ans
@@ -686,7 +696,9 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -793,9 +805,13 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
-            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -807,7 +823,9 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
+            logging.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,7 +912,9 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info("Return None as the resulting lattice is too large.")
+                logging.info(
+                    "Return None as the resulting lattice is too large."
+                )
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index 7b58ffbd4..b075aceac 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x**2
+        x = x ** 2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in ["value", "max", "min"]
+        assert stats_type in [ "value", "max", "min" ]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,9 +121,7 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = (
-            None  # we'll later assign a list to this data member.  It's a list of dict.
-        )
+        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -135,6 +133,7 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
+
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -186,12 +185,17 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
+                    if (
+                        this_dim_stats[stats_type] != []
+                        and stats_type == "eigs"
+                    ):
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                        # don't accumulate "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
+                        this_dim_stats[stats_type].append(
+                            TensorAndCount(stats, count)
+                        )
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -207,6 +211,7 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
+
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -216,8 +221,7 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list],
-                        dim=0,
+                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
                     )
 
                 if stats_type == "eigs":
@@ -225,7 +229,9 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print("Error getting eigenvalues, trying another method.")
+                        print(
+                            "Error getting eigenvalues, trying another method."
+                        )
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -236,9 +242,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
-                    stats.numel()
-                )
+                summarize = (
+                    len(stats_list) > 1
+                ) or self.opts.dim_is_summarized(stats.numel())
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -255,15 +261,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in ["value", "rms", "eigs"]:
+                if stats_type in [ "value", "rms", "eigs" ]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats**2).sum().sqrt().item()
+                    norm = (stats ** 2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats**2).mean().sqrt().item()
+                rms = (stats ** 2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -271,17 +277,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
-                )
-                maybe_class_name = (
-                    f" type={self.class_name}," if self.class_name is not None else ""
+                    f"{sizes[0]}"
+                    if len(sizes) == 1
+                    else f"{min(sizes)}..{max(sizes)}"
                 )
+                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
-                    f" {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
+
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -339,32 +345,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def forward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.output"].accumulate(_output,
+                                                                class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
+                                                                         class_name=type(_module).__name__)
 
-        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
+        def backward_hook(
+            _module, _input, _output, _model_diagnostic=ans, _name=name
+        ):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(
-                    _output, class_name=type(_module).__name__
-                )
+                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
+                                                              class_name=type(_module).__name__)
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
-                        o, class_name=type(_module).__name__
-                    )
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
+                                                                       class_name=type(_module).__name__)
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 9df1c5bd1..7016beafb 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,7 +29,9 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 373e9a9ff..8aeda6be2 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,7 +53,9 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        git_commit = (
+            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
+        )
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index e2ff03f61..570ed7d7a 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
+            transcript_fsa
+        )
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 398a5f689..fbcf5e148 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import random
-
 import torch
 from torch import Tensor, nn
+import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -57,7 +56,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
+                        f"The sum of {_name}.grad is not finite" # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -66,20 +65,28 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
+                        logging.warning(
+                            f"The sum of {_name}.grad[{i}] is not finite"
+                        )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
+
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(grad, _name=name):
+        def param_backward_hook(
+                grad, _name=name
+        ):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(f"The sum of {_name}.param_grad is not finite")
+                logging.warning(
+                    f"The sum of {_name}.param_grad is not finite"
+                )
 
         parameter.register_hook(param_backward_hook)
 
 
+
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 22e1b78bb..80bd7c1ee 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,12 +49,18 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
-                logging.info("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -113,7 +119,9 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError("It's assumed that each word has a unique pronunciation")
+        raise RuntimeError(
+            "It's assumed that each word has a unique pronunciation"
+        )
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 16ed6e032..2c479fc2c 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,7 +63,10 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes])
+        .t()
+        .reshape(-1)
+        .to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -112,12 +115,20 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=beam_size
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=beam_size
+    )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -157,9 +168,13 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    num_tot_scores = num_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
-    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
+    den_tot_scores = den_lats.get_tot_scores(
+        log_semiring=True, use_double_scores=True
+    )
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 9f680f83d..0d901227d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,7 +137,9 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
+        transcript_fsa_with_self_loops = k2.arc_sort(
+            transcript_fsa_with_self_loops
+        )
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -153,7 +155,9 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
+            indexes = torch.zeros(
+                len(texts), dtype=torch.int32, device=self.device
+            )
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 9a275bf28..550801a8f 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,19 +46,16 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -197,7 +194,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        "Number of model parameters (requires_grad): "
+        f"Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 4bf982503..598e329c4 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,8 +155,12 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
-        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
+        x = sentence_tokens_with_sos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
+        y = sentence_tokens_with_eos.pad(
+            mode="constant", padding_value=self.blank_id
+        )
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 2e878f5c8..094035fce 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,20 +38,17 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help=(
-            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
-        ),
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help=(
-            "Number of checkpoints to average. Automatically select "
-            "consecutive checkpoints before the checkpoint specified by "
-            "'--epoch'. "
-        ),
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
     )
 
     parser.add_argument(
@@ -162,7 +159,9 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index 9eef88840..a6144727a 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,7 +129,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -159,12 +161,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
-                device
-            )
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            ).to(device)
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -179,8 +181,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
-            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            h = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
+            c = torch.zeros(
+                self.rnn.num_layers, batch_size, self.rnn.input_size
+            )
 
         device = next(self.parameters()).device
 
@@ -188,7 +194,9 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        sentence_lengths = (
+            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
+        )
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index e17b50332..bb5f03fb9 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,13 +446,17 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/tot_ppl", tot_ppl, params.batch_idx_train
+                )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -467,7 +471,8 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
+                f"ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index a3bf1ef4c..c2edd823e 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,50 +15,30 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import argparse
-import io
-import math
+import sys
 import os
 import re
-import sys
+import io
+import math
+import argparse
 from collections import Counter, defaultdict
 
-parser = argparse.ArgumentParser(
-    description="""
+
+parser = argparse.ArgumentParser(description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """
-)
-parser.add_argument(
-    "-ngram-order",
-    type=int,
-    default=4,
-    choices=[2, 3, 4, 5, 6, 7],
-    help="Order of n-gram",
-)
+    """)
+parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument(
-    "-lm",
-    type=str,
-    default=None,
-    help="Path to output arpa file for language models",
-)
-parser.add_argument(
-    "-verbose",
-    type=int,
-    default=0,
-    choices=[0, 1, 2, 3, 4, 5],
-    help="Verbose level",
-)
+parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
+parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
 args = parser.parse_args()
 
-default_encoding = (
-    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-)
-# Need to be very careful about the use of strip() and split()
-# in this case, because there is a latin-1 whitespace character
-# (nbsp) which is part of the unicode encoding range.
-# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+                              # Need to be very careful about the use of strip() and split()
+                              # in this case, because there is a latin-1 whitespace character
+                              # (nbsp) which is part of the unicode encoding range.
+                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -72,9 +52,7 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(
-            set
-        )  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -84,15 +62,10 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return " total={0}: {1}".format(
+        return ' total={0}: {1}'.format(
             str(self.total_count),
-            ", ".join(
-                [
-                    "{0} -> {1}".format(word, count)
-                    for word, count in self.word_to_count.items()
-                ]
-            ),
-        )
+            ', '.join(['{0} -> {1}'.format(word, count)
+                      for word, count in self.word_to_count.items()]))
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -112,7 +85,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
+    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -130,48 +103,39 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(
-            predicted_word, context_word, count
-        )
+        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == "":
+        if line == '':
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order + 1):
+            for n in range(1, self.ngram_order+1):
                 if i + n > len(words):
                     break
-                ngram = words[i : i + n]
+                ngram = words[i: i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[:-1])
+                history = tuple(ngram[: -1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i - 1]
+                    context_word = words[i-1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(
-            sys.stdin.buffer, encoding=default_encoding
-        )  # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -181,12 +145,7 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print(
-                "make_phone_lm.py: processed {0} lines of input".format(
-                    lines_processed
-                ),
-                file=sys.stderr,
-            )
+            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -194,11 +153,9 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [
-            0
-        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-        # but perhaps this is not the case for some other scenarios.
+        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+                      # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -208,11 +165,9 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(
-                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
-            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
-            # which could happen if the number of symbols is small.
-            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
+                                                                # which could happen if the number of symbols is small.
+                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -227,9 +182,7 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = (
-                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
-                )
+                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -243,17 +196,11 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = (
-                            max((n_star_z - self.d[n]), 0)
-                            * 1.0
-                            / counts_for_hist.total_count
-                        )
+                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -293,18 +240,12 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for (
-                            u
-                        ) in (
-                            a_counts_for_hist.word_to_count.keys()
-                        ):  # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
-                                1.0 - sum_z1_f_z
-                            )
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -318,9 +259,7 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append(
-                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
-                    )
+                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -383,40 +322,27 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append(
-                            "{1}\t{0}\t{2}".format(
-                                ngram, math.log(f, 10), math.log(bow, 10)
-                            )
-                        )
+                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(
-        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
-    ):
+    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
         # print as ARPA format.
 
-        print("\\data\\", file=fout)
+        print('\\data\\', file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print(
-                "ngram {0}={1}".format(
-                    hist_len + 1,
-                    sum(
-                        [
-                            len(counts_for_hist.word_to_f)
-                            for counts_for_hist in self.counts[hist_len].values()
-                        ]
-                    ),
-                ),
-                file=fout,
+            print('ngram {0}={1}'.format(
+                hist_len + 1,
+                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
+                file=fout
             )
 
-        print("", file=fout)
+        print('', file=fout)
 
         for hist_len in range(self.ngram_order):
-            print("\\{0}-grams:".format(hist_len + 1), file=fout)
+            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -428,12 +354,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
+                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
                     if bow is not None:
-                        line += "\t{0}".format("%.7f" % math.log10(bow))
+                        line += '\t{0}'.format('%.7f' % math.log10(bow))
                     print(line, file=fout)
-            print("", file=fout)
-        print("\\end\\", file=fout)
+            print('', file=fout)
+        print('\\end\\', file=fout)
 
 
 if __name__ == "__main__":
@@ -453,5 +379,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, "w", encoding=default_encoding) as f:
+        with open(args.lm, 'w', encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
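
The cal_discounting_constants hunk above is the core of the Kneser-Ney smoothing this script implements: for each N-gram order it derives a discount D = n1 / (n1 + 2 * n2) from the number of N-grams seen exactly once (n1) and exactly twice (n2). Below is a minimal standalone sketch of that single step, using a hypothetical discount_constant helper in place of the script's NgramCounts bookkeeping; it is an illustration only, not code from this patch.

    # Kneser-Ney discounting constant, mirroring cal_discounting_constants above.
    # n1 / n2 are the counts of N-grams observed exactly once / exactly twice.
    def discount_constant(n1: int, n2: int) -> float:
        assert n1 + 2 * n2 > 0
        # The max(0.1, ...) floor avoids a zero discount when n1 == 0, which
        # would otherwise cause division by zero when back-off weights are computed.
        return max(0.1, n1 * 1.0) / (n1 + 2 * n2)

    # Example: 50 singleton and 20 doubleton bigrams give D = 50 / 90, about 0.556.
    print(discount_constant(50, 20))
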
diff --git a/icefall/utils.py b/icefall/utils.py
index 785bd80f9..143c79497 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,7 +130,9 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        formatter = (
+            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+        )
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -201,7 +203,7 @@ def encode_supervisions(
                 supervisions["num_frames"],
                 subsampling_factor,
                 rounding_mode="floor",
-            ),
+            )
         ),
         1,
     ).to(torch.int32)
@@ -286,9 +288,13 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
+            best_paths.arcs.shape()
+            .remove_axis(1)
+            .compose(best_paths.aux_labels.shape)
+        )
+        all_aux_labels = k2.RaggedTensor(
+            all_aux_shape, best_paths.aux_labels.values
         )
-        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -357,7 +363,9 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
+    tokens = k2.RaggedTensor(
+        token_shape, getattr(best_paths, kind).contiguous()
+    )
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,7 +586,9 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -588,7 +598,9 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -602,7 +614,9 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -777,7 +791,9 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    ref_word
+                    if ref_word == hyp_word
+                    else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -787,7 +803,9 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+    for count, (ref, hyp) in sorted(
+        [(v, k) for k, v in subs.items()], reverse=True
+    ):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -801,7 +819,9 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    print(
+        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
+    )
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -871,7 +891,9 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
+                float(v) / num_frames
+                if "utt_" not in k
+                else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -905,7 +927,9 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
+def concat(
+    ragged: k2.RaggedTensor, value: int, direction: str
+) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -951,8 +975,8 @@ def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTens
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. "             "Expect either "left"'
-            ' or "right"'
+            f'Unsupported direction: {direction}. " \
+            "Expect either "left" or "right"'
         )
     return ans
 
@@ -1077,7 +1101,9 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
+def measure_weight_norms(
+    model: nn.Module, norm: str = "l2"
+) -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1100,7 +1126,9 @@ def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]
         return norms
 
 
-def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
+def measure_gradient_norms(
+    model: nn.Module, norm: str = "l1"
+) -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1385,7 +1413,9 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
+        time = convert_timestamp(
+            res.timestamps[i], subsampling_factor, frame_shift_ms
+        )
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index 3183055d4..b4f8c3377 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 88
+line-length = 80
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index ccd2503ff..6c720e121 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
 
-from pathlib import Path
-
 from setuptools import find_packages, setup
+from pathlib import Path
 
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 34e829642..511a11c23 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,7 +20,11 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    load_checkpoint,
+    save_checkpoint,
+)
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 4c2e192a7..97964ac67 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,7 +23,6 @@ You can run this file in one of the two ways:
 """
 
 import k2
-
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index 10443cf22..ccfb57d49 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,7 +154,9 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
+        lattice = k2.intersect(
+            decoding_graph, fsas, treat_epsilons_specially=False
+        )
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 31f06bd51..6a9ce7853 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,7 +50,9 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
+            torch.tensor(
+                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
+            ),
         )
     )
     assert texts == ["two", "one", "three"]