From d110b04ad389134c82fa314e3aafc7b40043efb0 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 16 Nov 2022 13:06:43 -0500 Subject: [PATCH 1/3] apply new black formatting to all files --- .github/workflows/style_check.yml | 11 +- .pre-commit-config.yaml | 26 +- docker/README.md | 24 +- .../Dockerfile | 14 +- .../Dockerfile | 17 +- .../images/k2-gt-v1.9-blueviolet.svg | 2 +- .../images/python-gt-v3.6-blue.svg | 2 +- .../images/torch-gt-v1.6.0-green.svg | 2 +- docs/source/recipes/aishell/index.rst | 1 - docs/source/recipes/timit/index.rst | 1 - docs/source/recipes/timit/tdnn_ligru_ctc.rst | 28 +- docs/source/recipes/timit/tdnn_lstm_ctc.rst | 24 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/prepare_char.py | 8 +- .../ASR/local/prepare_lang.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- egs/aidatatang_200zh/ASR/local/text2token.py | 21 +- egs/aidatatang_200zh/ASR/prepare.sh | 3 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/aishell/ASR/conformer_ctc/conformer.py | 70 +- egs/aishell/ASR/conformer_ctc/decode.py | 29 +- egs/aishell/ASR/conformer_ctc/export.py | 17 +- egs/aishell/ASR/conformer_ctc/pretrained.py | 39 +- egs/aishell/ASR/conformer_ctc/subsampling.py | 16 +- .../ASR/conformer_ctc/test_subsampling.py | 3 +- egs/aishell/ASR/conformer_ctc/train.py | 12 +- egs/aishell/ASR/conformer_ctc/transformer.py | 44 +- egs/aishell/ASR/conformer_mmi/conformer.py | 70 +- egs/aishell/ASR/conformer_mmi/decode.py | 33 +- egs/aishell/ASR/conformer_mmi/subsampling.py | 16 +- egs/aishell/ASR/conformer_mmi/train.py | 8 +- egs/aishell/ASR/conformer_mmi/transformer.py | 44 +- .../local/compute_fbank_aidatatang_200zh.py | 8 +- .../ASR/local/compute_fbank_aishell.py | 8 +- egs/aishell/ASR/local/prepare_char.py | 8 +- egs/aishell/ASR/local/prepare_lang.py | 4 +- egs/aishell/ASR/local/test_prepare_lang.py | 4 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/export.py | 31 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless2/train.py | 64 +- .../pruned_transducer_stateless3/decode.py | 73 +- .../pruned_transducer_stateless3/export.py | 54 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../pretrained.py | 50 +- .../ASR/pruned_transducer_stateless3/train.py | 79 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 118 +- egs/aishell/ASR/tdnn_lstm_ctc/decode.py | 33 +- egs/aishell/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 37 +- egs/aishell/ASR/tdnn_lstm_ctc/train.py | 7 +- .../ASR/transducer_stateless/beam_search.py | 22 +- .../ASR/transducer_stateless/conformer.py | 70 +- .../ASR/transducer_stateless/decode.py | 39 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- egs/aishell/ASR/transducer_stateless/model.py | 4 +- .../ASR/transducer_stateless/pretrained.py | 36 +- egs/aishell/ASR/transducer_stateless/train.py | 15 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../asr_datamodule.py | 85 +- .../transducer_stateless_modified-2/decode.py | 46 +- .../transducer_stateless_modified-2/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified-2/train.py | 22 +- .../transducer_stateless_modified/decode.py | 46 +- .../transducer_stateless_modified/export.py | 20 +- .../pretrained.py | 50 +- .../transducer_stateless_modified/train.py | 15 +- egs/aishell2/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_aishell2.py | 8 +- .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 114 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- .../ASR/local/compute_fbank_aishell4.py | 8 +- egs/aishell4/ASR/local/prepare_char.py | 8 +- egs/aishell4/ASR/local/prepare_lang.py | 4 +- egs/aishell4/ASR/local/test_prepare_lang.py | 4 +- egs/aishell4/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 69 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 45 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_alimeeting.py | 8 +- egs/alimeeting/ASR/local/prepare_char.py | 8 +- egs/alimeeting/ASR/local/prepare_lang.py | 4 +- egs/alimeeting/ASR/local/test_prepare_lang.py | 4 +- egs/alimeeting/ASR/local/text2segments.py | 2 +- egs/alimeeting/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless2/decode.py | 60 +- .../pruned_transducer_stateless2/export.py | 20 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- egs/csj/ASR/.gitignore | 2 +- egs/csj/ASR/local/compute_fbank_csj.py | 38 +- egs/csj/ASR/local/compute_fbank_musan.py | 17 +- egs/csj/ASR/local/conf/disfluent.ini | 55 +- egs/csj/ASR/local/conf/fluent.ini | 55 +- egs/csj/ASR/local/conf/number.ini | 55 +- egs/csj/ASR/local/conf/symbol.ini | 55 +- .../ASR/local/display_manifest_statistics.py | 4 +- egs/csj/ASR/local/prepare_lang_char.py | 17 +- egs/csj/ASR/local/validate_manifest.py | 7 +- .../ASR/conformer_ctc/asr_datamodule.py | 117 +- egs/gigaspeech/ASR/conformer_ctc/conformer.py | 66 +- egs/gigaspeech/ASR/conformer_ctc/decode.py | 29 +- .../ASR/conformer_ctc/gigaspeech_scoring.py | 3 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/gigaspeech/ASR/conformer_ctc/train.py | 12 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/preprocess_gigaspeech.py | 10 +- .../asr_datamodule.py | 117 +- .../pruned_transducer_stateless2/decode.py | 42 +- .../pruned_transducer_stateless2/export.py | 24 +- .../ASR/pruned_transducer_stateless2/train.py | 48 +- egs/librispeech/ASR/conformer_ctc/ali.py | 25 +- .../ASR/conformer_ctc/conformer.py | 66 +- egs/librispeech/ASR/conformer_ctc/decode.py | 29 +- egs/librispeech/ASR/conformer_ctc/export.py | 17 +- .../ASR/conformer_ctc/label_smoothing.py | 7 +- .../ASR/conformer_ctc/pretrained.py | 33 +- .../ASR/conformer_ctc/subsampling.py | 16 +- egs/librispeech/ASR/conformer_ctc/train.py | 22 +- .../ASR/conformer_ctc/transformer.py | 49 +- .../ASR/conformer_ctc2/attention.py | 19 +- .../ASR/conformer_ctc2/conformer.py | 65 +- egs/librispeech/ASR/conformer_ctc2/decode.py | 56 +- egs/librispeech/ASR/conformer_ctc2/export.py | 49 +- egs/librispeech/ASR/conformer_ctc2/train.py | 39 +- .../ASR/conformer_ctc2/transformer.py | 50 +- .../ASR/conformer_mmi/conformer.py | 70 +- egs/librispeech/ASR/conformer_mmi/decode.py | 29 +- .../ASR/conformer_mmi/subsampling.py | 16 +- .../ASR/conformer_mmi/test_subsampling.py | 3 +- .../ASR/conformer_mmi/test_transformer.py | 9 +- .../ASR/conformer_mmi/train-with-attention.py | 27 +- egs/librispeech/ASR/conformer_mmi/train.py | 27 +- .../ASR/conformer_mmi/transformer.py | 28 +- .../decode.py | 69 +- .../emformer.py | 119 +- .../export.py | 47 +- .../stream.py | 8 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../decode.py | 69 +- .../emformer.py | 108 +- .../export.py | 47 +- .../streaming_decode.py | 75 +- .../train.py | 56 +- .../ASR/local/add_alignment_librispeech.py | 12 +- egs/librispeech/ASR/local/compile_hlg.py | 4 +- egs/librispeech/ASR/local/compile_lg.py | 4 +- .../compute_fbank_gigaspeech_dev_test.py | 4 +- .../local/compute_fbank_gigaspeech_splits.py | 10 +- .../ASR/local/compute_fbank_librispeech.py | 8 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../convert_transcript_words_to_tokens.py | 16 +- egs/librispeech/ASR/local/download_lm.py | 4 +- egs/librispeech/ASR/local/filter_cuts.py | 10 +- .../ASR/local/generate_unique_lexicon.py | 4 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 4 +- .../ASR/local/prepare_lm_training_data.py | 11 +- .../ASR/local/preprocess_gigaspeech.py | 4 +- .../ASR/local/test_prepare_lang.py | 4 +- .../ASR/local/validate_manifest.py | 7 +- .../ASR/lstm_transducer_stateless/decode.py | 818 ------------ .../ASR/lstm_transducer_stateless/export.py | 388 ------ .../jit_pretrained.py | 322 ----- .../ASR/lstm_transducer_stateless/lstm.py | 871 ------------- .../ASR/lstm_transducer_stateless/model.py | 210 --- .../lstm_transducer_stateless/pretrained.py | 352 ----- .../ASR/lstm_transducer_stateless/stream.py | 148 --- .../streaming_decode.py | 968 -------------- .../ASR/lstm_transducer_stateless/train.py | 1157 ----------------- .../ASR/lstm_transducer_stateless2/decode.py | 67 +- .../ASR/lstm_transducer_stateless2/export.py | 59 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless2/model.py | 8 +- .../lstm_transducer_stateless2/ncnn-decode.py | 15 +- .../lstm_transducer_stateless2/pretrained.py | 40 +- .../streaming-ncnn-decode.py | 27 +- .../streaming-onnx-decode.py | 45 +- .../ASR/lstm_transducer_stateless2/train.py | 68 +- .../ASR/lstm_transducer_stateless3/decode.py | 79 +- .../ASR/lstm_transducer_stateless3/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/lstm_transducer_stateless3/lstm.py | 14 +- .../lstm_transducer_stateless3/pretrained.py | 40 +- .../streaming_decode.py | 74 +- .../ASR/lstm_transducer_stateless3/train.py | 66 +- .../ASR/pruned2_knowledge/asr_datamodule.py | 125 +- .../ASR/pruned2_knowledge/beam_search.py | 18 +- .../ASR/pruned2_knowledge/conformer.py | 90 +- .../ASR/pruned2_knowledge/decode.py | 44 +- .../ASR/pruned2_knowledge/decoder.py | 4 +- .../ASR/pruned2_knowledge/decoder2.py | 84 +- .../ASR/pruned2_knowledge/export.py | 20 +- .../ASR/pruned2_knowledge/joiner.py | 4 +- .../ASR/pruned2_knowledge/model.py | 8 +- .../ASR/pruned2_knowledge/optim.py | 35 +- .../ASR/pruned2_knowledge/sampling.py | 184 +-- .../ASR/pruned2_knowledge/scaling.py | 51 +- .../ASR/pruned2_knowledge/scaling_tmp.py | 355 +++-- .../ASR/pruned2_knowledge/train.py | 50 +- .../pruned_stateless_emformer_rnnt2/decode.py | 69 +- .../emformer.py | 8 +- .../pruned_stateless_emformer_rnnt2/export.py | 47 +- .../pruned_stateless_emformer_rnnt2/model.py | 4 +- .../pruned_stateless_emformer_rnnt2/train.py | 44 +- .../beam_search.py | 26 +- .../ASR/pruned_transducer_stateless/decode.py | 44 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless/decoder.py | 4 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../ASR/pruned_transducer_stateless/model.py | 4 +- .../pruned_transducer_stateless/pretrained.py | 36 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless/train.py | 46 +- .../beam_search.py | 51 +- .../pruned_transducer_stateless2/conformer.py | 97 +- .../pruned_transducer_stateless2/decode.py | 50 +- .../pruned_transducer_stateless2/decoder.py | 8 +- .../pruned_transducer_stateless2/export.py | 24 +- .../pruned_transducer_stateless2/joiner.py | 4 +- .../ASR/pruned_transducer_stateless2/model.py | 8 +- .../ASR/pruned_transducer_stateless2/optim.py | 35 +- .../pretrained.py | 36 +- .../pruned_transducer_stateless2/scaling.py | 56 +- .../streaming_beam_search.py | 12 +- .../streaming_decode.py | 39 +- .../ASR/pruned_transducer_stateless2/train.py | 58 +- .../asr_datamodule.py | 85 +- .../decode-giga.py | 54 +- .../pruned_transducer_stateless3/decode.py | 74 +- .../pruned_transducer_stateless3/export.py | 32 +- .../gigaspeech.py | 8 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless3/model.py | 8 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 36 +- .../scaling_converter.py | 10 +- .../streaming_decode.py | 39 +- .../pruned_transducer_stateless3/test_onnx.py | 24 +- .../ASR/pruned_transducer_stateless3/train.py | 65 +- .../pruned_transducer_stateless4/decode.py | 79 +- .../pruned_transducer_stateless4/export.py | 47 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless4/train.py | 61 +- .../pruned_transducer_stateless5/conformer.py | 118 +- .../pruned_transducer_stateless5/decode.py | 67 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 66 +- .../pruned_transducer_stateless6/conformer.py | 67 +- .../pruned_transducer_stateless6/decode.py | 69 +- .../pruned_transducer_stateless6/export.py | 24 +- .../extract_codebook_index.py | 3 +- .../hubert_decode.py | 17 +- .../hubert_xlarge.py | 22 +- .../ASR/pruned_transducer_stateless6/model.py | 12 +- .../ASR/pruned_transducer_stateless6/train.py | 65 +- .../pruned_transducer_stateless6/vq_utils.py | 31 +- .../pruned_transducer_stateless7/decode.py | 67 +- .../pruned_transducer_stateless7/decoder.py | 6 +- .../pruned_transducer_stateless7/export.py | 47 +- .../jit_pretrained.py | 21 +- .../pruned_transducer_stateless7/joiner.py | 4 +- .../ASR/pruned_transducer_stateless7/model.py | 16 +- .../ASR/pruned_transducer_stateless7/optim.py | 439 ++++--- .../pretrained.py | 40 +- .../pruned_transducer_stateless7/scaling.py | 487 +++---- .../scaling_converter.py | 12 +- .../ASR/pruned_transducer_stateless7/train.py | 88 +- .../pruned_transducer_stateless7/zipformer.py | 660 +++++----- .../pruned_transducer_stateless8/decode.py | 67 +- .../pruned_transducer_stateless8/export.py | 47 +- .../jit_pretrained.py | 21 +- .../ASR/pruned_transducer_stateless8/model.py | 4 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless8/train.py | 99 +- .../ASR/streaming_conformer_ctc/README.md | 16 +- .../ASR/streaming_conformer_ctc/conformer.py | 116 +- .../streaming_decode.py | 68 +- .../ASR/streaming_conformer_ctc/train.py | 16 +- .../streaming_conformer_ctc/transformer.py | 40 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 113 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/librispeech/ASR/tdnn_lstm_ctc/model.py | 5 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 8 +- egs/librispeech/ASR/transducer/beam_search.py | 14 +- egs/librispeech/ASR/transducer/decode.py | 28 +- egs/librispeech/ASR/transducer/export.py | 17 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- egs/librispeech/ASR/transducer/rnn.py | 24 +- egs/librispeech/ASR/transducer/test_rnn.py | 16 +- egs/librispeech/ASR/transducer/train.py | 12 +- .../ASR/transducer_lstm/beam_search.py | 14 +- egs/librispeech/ASR/transducer_lstm/decode.py | 28 +- .../ASR/transducer_lstm/encoder.py | 4 +- egs/librispeech/ASR/transducer_lstm/train.py | 12 +- .../ASR/transducer_stateless/alignment.py | 4 +- .../ASR/transducer_stateless/beam_search.py | 28 +- .../ASR/transducer_stateless/compute_ali.py | 24 +- .../ASR/transducer_stateless/conformer.py | 107 +- .../ASR/transducer_stateless/decode.py | 42 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/joiner.py | 8 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../transducer_stateless/test_compute_ali.py | 11 +- .../transducer_stateless/test_conformer.py | 4 +- .../ASR/transducer_stateless/train.py | 23 +- .../ASR/transducer_stateless/transformer.py | 4 +- .../ASR/transducer_stateless2/decode.py | 42 +- .../ASR/transducer_stateless2/export.py | 20 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../ASR/transducer_stateless2/train.py | 23 +- .../decode.py | 42 +- .../export.py | 20 +- .../pretrained.py | 36 +- .../test_asr_datamodule.py | 4 +- .../train.py | 22 +- egs/ptb/LM/local/sort_lm_training_data.py | 4 +- .../LM/local/test_prepare_lm_training_data.py | 4 +- .../ASR/local/compute_fbank_musan.py | 8 +- .../ASR/local/compute_fbank_spgispeech.py | 14 +- egs/spgispeech/ASR/local/prepare_splits.py | 8 +- .../asr_datamodule.py | 100 +- .../pruned_transducer_stateless2/decode.py | 66 +- .../pruned_transducer_stateless2/export.py | 26 +- .../ASR/pruned_transducer_stateless2/train.py | 51 +- .../ASR/local/compute_fbank_tal_csasr.py | 8 +- egs/tal_csasr/ASR/local/prepare_char.py | 4 +- egs/tal_csasr/ASR/local/prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/test_prepare_lang.py | 4 +- egs/tal_csasr/ASR/local/text2token.py | 21 +- .../asr_datamodule.py | 110 +- .../pruned_transducer_stateless5/decode.py | 77 +- .../pruned_transducer_stateless5/export.py | 47 +- .../pretrained.py | 40 +- .../ASR/pruned_transducer_stateless5/train.py | 59 +- .../ASR/local/compute_fbank_tedlium.py | 8 +- .../convert_transcript_words_to_bpe_ids.py | 4 +- egs/tedlium3/ASR/local/prepare_lexicon.py | 11 +- egs/tedlium3/ASR/local/prepare_transcripts.py | 11 +- .../ASR/pruned_transducer_stateless/decode.py | 38 +- .../ASR/pruned_transducer_stateless/export.py | 20 +- .../pruned_transducer_stateless/pretrained.py | 41 +- .../ASR/pruned_transducer_stateless/train.py | 35 +- .../transducer_stateless/asr_datamodule.py | 118 +- .../ASR/transducer_stateless/beam_search.py | 30 +- .../ASR/transducer_stateless/decode.py | 31 +- .../ASR/transducer_stateless/decoder.py | 4 +- .../ASR/transducer_stateless/export.py | 20 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless/train.py | 11 +- egs/timit/ASR/RESULTS.md | 2 +- egs/timit/ASR/local/compile_hlg.py | 4 +- egs/timit/ASR/local/compute_fbank_timit.py | 8 +- egs/timit/ASR/local/prepare_lexicon.py | 8 +- egs/timit/ASR/prepare.sh | 4 +- egs/timit/ASR/tdnn_ligru_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_ligru_ctc/model.py | 12 +- egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_ligru_ctc/train.py | 4 +- egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 104 +- egs/timit/ASR/tdnn_lstm_ctc/decode.py | 29 +- egs/timit/ASR/tdnn_lstm_ctc/model.py | 5 +- egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 43 +- egs/timit/ASR/tdnn_lstm_ctc/train.py | 4 +- .../compute_fbank_wenetspeech_dev_test.py | 11 +- .../local/compute_fbank_wenetspeech_splits.py | 10 +- egs/wenetspeech/ASR/local/prepare_char.py | 8 +- .../ASR/local/preprocess_wenetspeech.py | 6 +- egs/wenetspeech/ASR/local/text2token.py | 21 +- egs/wenetspeech/ASR/prepare.sh | 2 +- .../asr_datamodule.py | 121 +- .../pruned_transducer_stateless2/decode.py | 64 +- .../pruned_transducer_stateless2/export.py | 28 +- .../jit_pretrained.py | 21 +- .../onnx_check.py | 24 +- .../onnx_pretrained.py | 27 +- .../pretrained.py | 41 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +- .../pruned_transducer_stateless5/conformer.py | 97 +- .../pruned_transducer_stateless5/decode.py | 75 +- .../decode_stream.py | 19 +- .../pruned_transducer_stateless5/export.py | 20 +- .../pretrained.py | 41 +- .../streaming_beam_search.py | 8 +- .../streaming_decode.py | 62 +- .../ASR/pruned_transducer_stateless5/train.py | 67 +- egs/yesno/ASR/local/compile_hlg.py | 4 +- egs/yesno/ASR/local/compute_fbank_yesno.py | 12 +- egs/yesno/ASR/tdnn/asr_datamodule.py | 74 +- egs/yesno/ASR/tdnn/decode.py | 29 +- egs/yesno/ASR/tdnn/pretrained.py | 37 +- egs/yesno/ASR/tdnn/train.py | 4 +- egs/yesno/ASR/transducer/decode.py | 25 +- egs/yesno/ASR/transducer/train.py | 4 +- icefall/char_graph_compiler.py | 8 +- icefall/checkpoint.py | 12 +- icefall/decode.py | 36 +- icefall/diagnostics.py | 80 +- icefall/dist.py | 4 +- icefall/env.py | 4 +- icefall/graph_compiler.py | 4 +- icefall/hooks.py | 19 +- icefall/lexicon.py | 16 +- icefall/mmi.py | 29 +- icefall/mmi_graph_compiler.py | 8 +- icefall/rnn_lm/compute_perplexity.py | 15 +- icefall/rnn_lm/dataset.py | 8 +- icefall/rnn_lm/export.py | 17 +- icefall/rnn_lm/model.py | 28 +- icefall/rnn_lm/train.py | 11 +- icefall/shared/make_kn_lm.py | 184 ++- icefall/utils.py | 64 +- pyproject.toml | 2 +- setup.py | 3 +- test/test_checkpoint.py | 6 +- test/test_decode.py | 1 + test/test_graph_compiler.py | 4 +- test/test_utils.py | 4 +- 440 files changed, 6789 insertions(+), 14532 deletions(-) mode change 100755 => 100644 egs/aishell2/ASR/local/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py mode change 100755 => 100644 egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/export.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py mode change 100755 => 100644 egs/librispeech/ASR/lstm_transducer_stateless/train.py diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 90459bc1c..45d261ccc 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,17 +45,18 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 - # See https://github.com/psf/black/issues/2964 - # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + # Click issue fixed in https://github.com/psf/black/pull/2966 - name: Run flake8 shell: bash working-directory: ${{github.workspace}} run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --show-source --statistics - flake8 . + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 - name: Run black shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 446ba0fe7..e2055801b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,38 @@ repos: - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=80] + args: ["--line-length=88"] additional_dependencies: ['click==8.0.1'] exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 5.0.4 hooks: - id: flake8 - args: [--max-line-length=80] + args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] + + # What are we ignoring here? + # E203: whitespace before ':' + # E266: too many leading '#' for block comment + # E501: line too long + # F401: module imported but unused + # E402: module level import not at top of file + # F403: 'from module import *' used; unable to detect undefined names + # F841: local variable is assigned to but never used + # W503: line break before binary operator + # In addition, the default ignore list is: + # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.9.2 + rev: 5.10.1 hooks: - id: isort - args: [--profile=black, --line-length=80] + args: ["--profile=black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-executables-have-shebangs - id: end-of-file-fixer diff --git a/docker/README.md b/docker/README.md index 6f2314e96..c14b9bf75 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. +If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0. @@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with ```bash $ nvidia-smi -Tue Sep 20 00:26:13 2022 +Tue Sep 20 00:26:13 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022 | 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022 ``` ## Building images locally -If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. -For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. ```dockerfile ENV http_proxy=http://aaa.bb.cc.net:8080 \ https_proxy=http://aaa.bb.cc.net:8080 ``` -Then, proceed with these commands. +Then, proceed with these commands. ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: @@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. -Overall, your docker run command should look like this. +Overall, your docker run command should look like this. ```bash docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 @@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re ### Linking to icefall in your host machine -If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. -Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. Use these commands once you are inside the container. @@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall docker exec -it icefall /bin/bash ``` -## Restarting a killed container that has been run before. +## Restarting a killed container that has been run before. ```bash docker start -ai icefall ``` @@ -111,4 +111,4 @@ docker start -ai icefall ## Sample usage of the CPU based images: ```bash docker run -it icefall /bin/bash -``` +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 3637d2f11..ff9e40604 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 # install normal source RUN apt-get update && \ @@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.12 && \ pip install graphviz - + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -68,7 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt -RUN pip install kaldifeat +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 17a8215f9..5c7423fa5 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,12 +1,12 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 RUN rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/nvidia-ml.list && \ apt-key del 7fa2af80 - + # install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ @@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18 curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ - rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* && \ mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \ @@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.7.1 && \ pip install graphviz @@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install git+https://github.com/lhotse-speech/lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ @@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall - diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg index 534b2e534..3019ff03d 100644 --- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg +++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg @@ -1 +1 @@ -k2: >= v1.9k2>= v1.9 \ No newline at end of file +k2: >= v1.9k2>= v1.9 diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg index 4254dc58a..df677ad09 100644 --- a/docs/source/installation/images/python-gt-v3.6-blue.svg +++ b/docs/source/installation/images/python-gt-v3.6-blue.svg @@ -1 +1 @@ -python: >= 3.6python>= 3.6 \ No newline at end of file +python: >= 3.6python>= 3.6 diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg index d3ece9a17..d7007d742 100644 --- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg +++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg @@ -1 +1 @@ -torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file +torch: >= 1.6.0torch>= 1.6.0 diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst index d072d6e9c..b77d59bca 100644 --- a/docs/source/recipes/aishell/index.rst +++ b/docs/source/recipes/aishell/index.rst @@ -19,4 +19,3 @@ It can be downloaded from ``_ tdnn_lstm_ctc conformer_ctc stateless_transducer - diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst index 17f40cdb7..5ee147be7 100644 --- a/docs/source/recipes/timit/index.rst +++ b/docs/source/recipes/timit/index.rst @@ -6,4 +6,3 @@ TIMIT tdnn_ligru_ctc tdnn_lstm_ctc - diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst index 186420ee7..3d7aefe02 100644 --- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst +++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst @@ -148,10 +148,10 @@ Some commonly used options are: $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17 - uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, - ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, - ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, - ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, + uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, + ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, + ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, + ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_ligru_ctc/pretrained.py + ./tdnn_ligru_ctc/pretrained.py --method 1best - --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt - --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt - --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt + --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt + --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -337,7 +337,7 @@ The output is: 2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started 2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 20:41:39,829 INFO [pretrained.py:267] + 2021-11-08 20:41:39,829 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh @@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \ --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.1 \ - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -378,7 +378,7 @@ The decoding output is: 2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started 2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:37:56,348 INFO [pretrained.py:267] + 2021-11-08 20:37:56,348 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/timit/tdnn_lstm_ctc.rst index 6f760a9ce..ee67a6edc 100644 --- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/timit/tdnn_lstm_ctc.rst @@ -148,8 +148,8 @@ Some commonly used options are: $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10 - uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, - ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, + uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, + ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_lstm_ctc/pretrained.py + ./tdnn_lstm_ctc/pretrained.py --method 1best - --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt - --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt - --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt + --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt + --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -335,7 +335,7 @@ The output is: 2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started 2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 21:02:54,387 INFO [pretrained.py:267] + 2021-11-08 21:02:54,387 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh @@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.08 \ - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -376,7 +376,7 @@ The decoding output is: 2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:05:27,878 INFO [pretrained.py:267] + 2021-11-08 20:05:27,878 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index fb2751c0f..387c14acf 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 039951354..4749e1b7f 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -106,11 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! -f $lang_char_dir/words.txt ]; then ./local/prepare_words.py \ --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt + --output-file $lang_char_dir/words.txt fi if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py fi fi - diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 6a5b57e24..8c94f5bea 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class Aidatatang_200zhAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class Aidatatang_200zhAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class Aidatatang_200zhAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class Aidatatang_200zhAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class Aidatatang_200zhAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class Aidatatang_200zhAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +354,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index f0407f429..3f582ef04 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -69,11 +69,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -92,25 +88,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -192,8 +193,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 00b54c39f..34f4d3ddf 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index eb5e6b0d4..3c96ed07b 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index d46838b68..c7b1a4266 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 751b7d5b5..a30fa52df 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -58,16 +58,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -401,9 +404,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -431,9 +432,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -441,9 +440,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -562,9 +559,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 42b8c29e7..9ee405e8b 100644 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=84, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=25, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 27776bc24..e3d5a20e3 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -46,27 +46,29 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( "--tokens-file", type=str, - help="Path to tokens.txt" "Used only when method is ctc-decoding", + help="Path to tokens.txtUsed only when method is ctc-decoding", ) parser.add_argument( "--words-file", type=str, - help="Path to words.txt" "Used when method is NOT ctc-decoding", + help="Path to words.txtUsed when method is NOT ctc-decoding", ) parser.add_argument( "--HLG", type=str, - help="Path to HLG.pt." "Used when method is NOT ctc-decoding", + help="Path to HLG.pt.Used when method is NOT ctc-decoding", ) parser.add_argument( @@ -163,10 +165,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -210,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -274,9 +277,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -371,9 +372,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index a228cc1fe..c2cbe6e3b 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -382,9 +382,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -520,9 +518,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -630,9 +626,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index cb7205e51..f5b5873b4 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,33 +691,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -766,9 +746,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +758,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +792,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +817,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 4db367e36..a43183063 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -59,16 +59,19 @@ def get_parser(): "--epoch", type=int, default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -413,9 +416,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -443,9 +444,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -453,9 +452,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -550,9 +547,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -581,9 +576,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py index 720ed6c22..398837a46 100644 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ b/egs/aishell/ASR/conformer_mmi/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 685831d09..09cd6e60c 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -511,9 +511,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +623,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 42700a972..037971927 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index deab6c809..115ca1031 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ b/egs/aishell/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ b/egs/aishell/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index a12934d55..ae926ec66 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +246,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -263,10 +258,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -310,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -387,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -415,9 +401,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -428,8 +412,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -473,9 +456,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -504,8 +485,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index feababdd2..5f6888db4 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -50,11 +50,7 @@ from pathlib import Path import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -157,8 +154,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -191,9 +187,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -201,17 +195,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 3c38e5db7..f754a7b9e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -256,13 +260,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -310,9 +310,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -329,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 97d892754..66ca23035 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -49,7 +49,6 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn - from asr_datamodule import AishellAsrDataModule from conformer import Conformer from decoder import Decoder @@ -75,9 +74,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -203,8 +200,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -227,42 +223,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -561,11 +560,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -593,23 +588,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -725,9 +713,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -891,7 +877,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1029,9 +1015,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index d159e420b..6c505940d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -121,20 +121,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -202,8 +206,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -263,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -277,10 +278,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +322,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -429,9 +421,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -442,8 +432,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -488,9 +477,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -518,13 +505,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -551,13 +537,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -586,7 +571,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 566902a85..e5a5d7c77 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -88,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -132,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -166,13 +169,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -195,13 +197,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -229,7 +230,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -252,9 +253,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -262,17 +261,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index e150e8230..a4dda0d6d 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_datatang = decoder_datatang self.joiner_datatang = joiner_datatang - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_datatang is not None: @@ -179,9 +177,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 04a0a882a..109879952 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) add_model_arguments(parser) @@ -196,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,13 +261,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -311,9 +311,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -330,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index feaef5cf6..b24f533ff 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -96,9 +96,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -224,8 +222,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -248,42 +245,45 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -635,11 +635,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -670,23 +666,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -824,9 +813,7 @@ def train_one_epoch( ) # summary stats if datatang_train_dl is not None: - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info if aishell: aishell_tot_loss = ( @@ -847,9 +834,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -892,9 +877,7 @@ def train_one_epoch( cur_lr = scheduler.get_last_lr()[0] if datatang_train_dl is not None: datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = ( - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - ) + tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " else: tot_loss_str = "" datatang_str = "" @@ -1067,7 +1050,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1076,9 +1059,7 @@ def run(rank, world_size, args): train_cuts = filter_short_and_long_utterances(train_cuts) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1093,9 +1074,7 @@ def run(rank, world_size, args): if params.datatang_prob > 0: datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances( - train_datatang_cuts - ) + train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) train_datatang_cuts = train_datatang_cuts.repeat(times=None) datatang_train_dl = asr_datamodule.train_dataloaders( train_datatang_cuts, @@ -1249,9 +1228,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index d24ba6bb7..12ae6e7d4 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -64,10 +64,12 @@ class AishellAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -79,59 +81,74 @@ class AishellAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -143,17 +160,18 @@ class AishellAsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -167,40 +185,40 @@ class AishellAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -215,9 +233,7 @@ class AishellAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -260,9 +276,7 @@ class AishellAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -308,9 +322,7 @@ class AishellAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -366,13 +378,9 @@ class AishellAsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 66b734fc4..8ef247438 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -49,16 +49,19 @@ def get_parser(): "--epoch", type=int, default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--method", @@ -265,9 +268,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -289,9 +290,7 @@ def save_results( # We compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) test_set_wers[key] = wer @@ -335,9 +334,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -362,9 +359,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") model.to(device) model.eval() @@ -392,9 +387,7 @@ def main(): lexicon=lexicon, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py index 5e04c11b4..1731e1ebe 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py @@ -66,10 +66,7 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] + [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 9bd810809..52f9410cf 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -41,9 +41,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -53,9 +55,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -71,10 +71,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -112,10 +114,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -173,9 +174,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is [N, C, T] with torch.no_grad(): @@ -219,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index 7619b0551..e574cf89b 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..de0a8d0f5 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -47,9 +47,9 @@ def greedy_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -81,9 +81,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -157,9 +157,7 @@ class HypothesisList(object): """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -246,9 +244,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 64114253d..e26c6c385 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -701,33 +689,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -764,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -778,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -814,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -843,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 780b0c4bb..1f7bb14e1 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -52,16 +52,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -99,8 +102,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -248,9 +248,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": @@ -319,9 +317,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -346,9 +342,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -359,8 +353,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -430,9 +423,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index c2c6552a9..70e9e6c96 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -86,9 +86,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 4c6519b96..e35b26fe0 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -69,17 +69,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -110,8 +113,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -243,9 +245,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 994305fc1..591bbe44f 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -103,9 +103,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index db89c4d67..8effc9815 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -73,9 +73,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -100,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -117,8 +121,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -211,10 +214,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -273,9 +275,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -319,9 +319,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index d54157709..62ffff473 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -126,8 +126,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -389,9 +388,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -504,9 +501,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +620,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index e851dcc32..b3ff153c1 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -250,9 +250,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 838e53658..76e209f06 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -29,10 +29,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -46,59 +43,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -112,18 +119,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -137,9 +148,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -162,9 +175,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -173,9 +184,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -252,9 +261,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index ea3f94fd8..fd4cb8385 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -93,16 +93,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -170,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,9 +229,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -241,10 +241,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -288,11 +285,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -365,9 +358,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -393,9 +384,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -406,8 +395,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -448,9 +436,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index 3bd2ceb11..32481829c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index a95a4bc52..55701a007 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 225d0d709..8fb7d1e49 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -149,8 +149,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -168,8 +167,7 @@ def get_parser(): "--datatang-prob", type=float, default=0.2, - help="The probability to select a batch from the " - "aidatatang_200zh dataset", + help="The probability to select a batch from the aidatatang_200zh dataset", ) return parser @@ -449,9 +447,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -605,9 +601,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) aishell_tot_loss.write_summary( tb_writer, "train/aishell_tot_", params.batch_idx_train ) @@ -735,9 +729,7 @@ def run(rank, world_size, args): train_datatang_cuts = train_datatang_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -776,9 +768,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 65fcda873..1e41942da 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -94,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -171,8 +174,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -231,9 +233,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -245,10 +245,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -292,11 +289,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -369,9 +362,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -397,9 +388,7 @@ def save_results( # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -410,8 +399,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) @@ -452,9 +440,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 11335a834..ca1d4bd4a 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -68,17 +68,20 @@ def get_parser(): "--epoch", type=int, default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -109,8 +112,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +243,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 262e822c2..038090461 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -87,9 +87,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -115,10 +117,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -165,15 +169,16 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", type=int, default=1, - help="Maximum number of symbols per frame. " - "Use only when --method is greedy_search", + help=( + "Maximum number of symbols per frame. " + "Use only when --method is greedy_search" + ), ) return parser @@ -194,10 +199,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +258,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +308,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +325,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index d3ffccafa..5f116f2bd 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -142,8 +142,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -414,9 +413,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -529,9 +526,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -657,9 +652,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index d8d3622bd..ec0c584ca 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py old mode 100755 new mode 100644 index b7a21f579..e8966b554 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -76,10 +76,12 @@ class AiShell2AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -91,59 +93,74 @@ class AiShell2AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--drop-last", @@ -155,17 +172,18 @@ class AiShell2AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -179,18 +197,22 @@ class AiShell2AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -216,20 +238,16 @@ class AiShell2AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +262,7 @@ class AiShell2AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -290,9 +306,7 @@ class AiShell2AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -348,9 +362,7 @@ class AiShell2AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -406,9 +418,7 @@ class AiShell2AsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> CutSet: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 915737f4a..64b64d1b1 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -168,20 +168,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -269,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -348,9 +351,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -409,10 +410,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -538,9 +536,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -573,8 +569,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -625,9 +620,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -661,13 +654,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -690,13 +682,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -724,7 +715,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -749,9 +740,7 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index bc7bd71cb..547ce2069 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -167,13 +170,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -196,13 +198,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -230,7 +231,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -266,9 +267,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 09de1bece..4b16511e8 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -81,9 +81,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -109,10 +111,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -159,8 +163,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -191,10 +194,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,15 +256,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -334,9 +332,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 838a0497f..d37e7bdca 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -220,8 +218,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -244,42 +241,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -603,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -636,23 +632,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -771,9 +760,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -829,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -939,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1104,9 +1089,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 3f50d9e3e..400c406f0 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -120,9 +118,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell4/ASR/local/prepare_char.py +++ b/egs/aishell4/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ b/egs/aishell4/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/aishell4/ASR/local/text2token.py +++ b/egs/aishell4/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7aa53ddda..84c7f0443 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -74,10 +74,12 @@ class Aishell4AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( @@ -91,66 +93,81 @@ class Aishell4AsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( @@ -164,17 +181,18 @@ class Aishell4AsrDataModule: "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -188,18 +206,22 @@ class Aishell4AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -222,24 +244,20 @@ class Aishell4AsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -254,9 +272,7 @@ class Aishell4AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -300,9 +316,7 @@ class Aishell4AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -359,9 +373,7 @@ class Aishell4AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 14e44c7d9..616a88937 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -117,20 +117,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -201,8 +205,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -260,9 +263,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -277,10 +278,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -326,11 +324,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -436,8 +428,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -480,9 +471,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -510,13 +499,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -543,13 +531,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +565,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index 993341131..3c580ff7b 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -136,8 +140,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -169,13 +172,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -202,13 +204,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -237,7 +238,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -276,9 +277,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 1fa893637..8151442af 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -94,9 +94,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -122,10 +124,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,10 +207,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -266,15 +268,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -306,10 +304,7 @@ def main(): for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 0a48b9059..aacd23ecd 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -85,9 +85,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -213,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -237,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -599,11 +599,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -633,22 +629,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -827,9 +816,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -937,7 +924,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index af926aa53..96115a230 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -121,9 +119,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ b/egs/alimeeting/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ b/egs/alimeeting/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py index 7c1019aa8..27b904fc8 100644 --- a/egs/alimeeting/ASR/local/text2segments.py +++ b/egs/alimeeting/ASR/local/text2segments.py @@ -30,8 +30,8 @@ with word segmenting: import argparse -import paddle import jieba +import paddle from tqdm import tqdm paddle.enable_static() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py index 71be2a613..2be639b7a 100755 --- a/egs/alimeeting/ASR/local/text2token.py +++ b/egs/alimeeting/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help="number of characters to split, i.e., \ - aabb -> a a b b with -n 1 and aa bb with -n 2", + help=( + "number of characters to split, i.e., aabb -> a a b" + " b with -n 1 and aa bb with -n 2" + ), ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +66,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +106,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +132,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index bf6faad7a..d0467a29e 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,10 +81,12 @@ class AlimeetingAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -96,75 +98,91 @@ class AlimeetingAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=300, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +196,22 @@ class AlimeetingAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -205,24 +227,20 @@ class AlimeetingAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -237,9 +255,7 @@ class AlimeetingAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +298,7 @@ class AlimeetingAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -341,9 +355,7 @@ class AlimeetingAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 6358fe970..ffaca1021 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -70,11 +70,7 @@ from beam_search import ( from lhotse.cut import Cut from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -93,25 +89,30 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--batch", type=int, default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -193,8 +194,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +249,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +264,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -315,11 +310,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -390,9 +381,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -425,8 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -563,8 +551,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -574,8 +561,7 @@ def main(): ) test_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) ] cuts_test_webdataset = CutSet.from_webdataset( test_shards, @@ -588,9 +574,7 @@ def main(): return 1.0 <= c.duration cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter( - remove_short_and_long_utt - ) + cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8beec1b8a..482e52d83 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -62,17 +62,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -103,8 +106,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +175,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index 93b1e1f57..afbf0960a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -85,9 +85,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -112,10 +114,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -162,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +196,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +259,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +284,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +336,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 81a0ede7f..158ea9c1b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,42 +185,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -542,22 +543,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +705,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +805,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index 5d965832e..cd0e20c4c 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -5,4 +5,4 @@ notify_tg.py finetune_* misc.ini .vscode/* -offline/* \ No newline at end of file +offline/* diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 994dedbdd..036ce925f 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -25,15 +25,10 @@ from random import Random from typing import List, Tuple import torch -from lhotse import ( +from lhotse import ( # fmt: off; See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527; fmt: on CutSet, Fbank, FbankConfig, - # fmt: off - # See the following for why LilcomChunkyWriter is preferred - # https://github.com/k2-fsa/icefall/pull/404 - # https://github.com/lhotse-speech/lhotse/pull/527 - # fmt: on LilcomChunkyWriter, RecordingSet, SupervisionSet, @@ -81,17 +76,13 @@ def make_cutset_blueprints( cut_sets.append((f"eval{i}", cut_set)) # Create train and valid cuts - logging.info( - "Loading, trimming, and shuffling the remaining core+noncore cuts." - ) + logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") recording_set = RecordingSet.from_file( manifest_dir / "csj_recordings_core.jsonl.gz" ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") supervision_set = SupervisionSet.from_file( manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file( - manifest_dir / "csj_supervisions_noncore.jsonl.gz" - ) + ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") cut_set = CutSet.from_manifests( recordings=recording_set, @@ -101,15 +92,12 @@ def make_cutset_blueprints( cut_set = cut_set.shuffle(Random(RNG_SEED)) logging.info( - "Creating valid and train cuts from core and noncore," - f"split at {split}." + f"Creating valid and train cuts from core and noncore,split at {split}." ) valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) train_set = CutSet.from_cuts(islice(cut_set, split, None)) - train_set = ( - train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - ) + train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) cut_sets.extend([("valid", valid_set), ("train", train_set)]) @@ -122,15 +110,9 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) - parser.add_argument( - "--split", type=int, default=4000, help="Split at this index" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument("--split", type=int, default=4000, help="Split at this index") return parser.parse_args() @@ -141,9 +123,7 @@ def main(): extractor = Fbank(FbankConfig(num_mel_bins=80)) num_jobs = min(16, os.cpu_count()) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index 44a33c4eb..f60e62c85 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -26,7 +26,6 @@ from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor - ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. It looks for manifests in the directory data/manifests. @@ -84,9 +83,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -107,21 +104,15 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" - ) - parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") + parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") return parser.parse_args() if __name__ == "__main__": args = get_args() - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index eb70673de..c987e72c5 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..f7f27f5bc 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..cf9038f62 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,4 +318,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..f9801284b 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = + = ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -

= +

= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,4 +319,3 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo - diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..c043cf853 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..f0078421b 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,8 +68,7 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. " - "If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" ), ) @@ -87,9 +86,7 @@ def main(): args = get_args() logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", level=logging.INFO, ) @@ -111,8 +108,7 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." ) for cut in train_set: try: @@ -123,8 +119,7 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -143,9 +138,7 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) + (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..89448a49c 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,8 +68,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -89,9 +88,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..c3e3e84bf 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,10 +61,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -76,75 +78,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -158,18 +176,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -183,30 +205,25 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -221,9 +238,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +271,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +317,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +373,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 9c1418baa..b38ae9c8c 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,16 +62,19 @@ def get_parser(): "--epoch", type=int, default=0, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -476,9 +479,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -493,9 +494,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +527,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +702,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index ef53b77f8..880aa76e2 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,8 +73,7 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result via" - "SCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..0ee845ec8 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,8 +47,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." + ), ) parser.add_argument( @@ -134,9 +136,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..9ae3f071e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,10 +73,12 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--manifest-dir", @@ -88,75 +90,91 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -170,18 +188,22 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it " - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), ) # GigaSpeech specific arguments @@ -195,8 +217,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders( @@ -216,20 +237,16 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -244,9 +261,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +304,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +360,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -405,9 +416,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..9f5d4711b 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,11 +77,7 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -118,9 +114,11 @@ def get_parser(): "--avg", type=int, default=8, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -188,8 +186,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +255,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +270,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -324,11 +316,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -398,9 +386,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -434,8 +420,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,8 +496,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..17f8614dc 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..4d1a2356d 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -178,42 +176,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -553,23 +554,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -732,9 +726,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..0169d0f82 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,16 +61,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -231,9 +234,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +259,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +288,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..1153a814c 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,33 +696,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -771,9 +751,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +763,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +797,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..66fdf82d9 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,16 +64,19 @@ def get_parser(): "--epoch", type=int, default=77, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=55, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -551,9 +554,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +569,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +601,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +806,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..bdb8a85e5 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,17 +40,20 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -157,9 +160,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..cb0d6e04d 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,13 +82,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..8cabf1a53 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,9 +48,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -189,10 +191,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -236,10 +240,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -300,9 +303,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +428,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..1a1c2f4c5 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -660,7 +652,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..a6f1679ef 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,33 +714,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -784,9 +769,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +777,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +813,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +838,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..934177b1f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -130,11 +132,13 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -658,9 +662,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +677,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +709,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,13 +850,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -881,13 +878,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +911,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -985,9 +981,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..0e1841d8d 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -89,20 +88,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -177,13 +180,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -206,13 +208,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -240,7 +241,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 9d9c2af1f..63534b76b 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -505,11 +503,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -546,9 +540,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -575,9 +567,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -595,9 +585,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -735,8 +723,7 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} " - f"failing batch names {batch_name}" + f"failing batch size:{batch_size} failing batch names {batch_name}" ) raise @@ -791,9 +778,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -806,9 +793,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -900,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - f"Unsupported type of lang dir (we expected it to have " + "Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index fa179acc0..8f0c7dcde 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -659,9 +643,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class TransformerEncoder(nn.Module): @@ -982,9 +964,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -1005,9 +985,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..4d9ddaea9 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,33 +690,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -765,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +816,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..e8390ded9 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,16 +60,19 @@ def get_parser(): "--epoch", type=int, default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -478,9 +481,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +513,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +652,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -687,9 +684,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..f8c94cff9 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -370,10 +361,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -762,19 +750,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..5cfb2bfc7 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,23 +36,14 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -377,10 +368,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -770,19 +758,14 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..a1c43f7f5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..0639ba746 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..59105e286 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..abe83732a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..a76417e5f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..9cb4a5afc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,20 +135,24 @@ def get_parser(): "--avg", type=int, default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -215,8 +219,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +302,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,11 +348,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -462,8 +454,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -506,9 +497,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,13 +529,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,13 +557,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -603,7 +590,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..09200f2e1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..4d05b367c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,9 +103,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -136,19 +138,20 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) add_model_arguments(parser) @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..0486ac2eb 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,8 +113,9 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -131,20 +132,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -211,8 +216,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +375,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +392,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +551,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +601,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -782,8 +776,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -831,9 +824,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,13 +858,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -896,13 +886,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -930,7 +919,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..2c2593b56 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,42 +263,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -636,11 +637,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +665,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +861,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,7 +969,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index fe6a26c51..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not have alignment." - ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..295156ed5 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -150,9 +150,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..19bf3bff4 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,9 +132,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..37fce11f4 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,8 +48,10 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", + help=( + "The maximum number of audio seconds in a batch." + "Determines batch size dynamically." + ), ) parser.add_argument( @@ -144,9 +146,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..9f8503814 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,9 +112,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -128,9 +126,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..4a4093ae4 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,9 +83,7 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -101,9 +99,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..f149b7871 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,21 +46,19 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", + help=( + "The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words." + ), ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index dff98a954..fbcc9e24a 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) removed += 1 return False @@ -125,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." ) return ans @@ -155,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..e121aefa9 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,9 +150,7 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..70343fef7 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..807aaf891 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,8 +64,7 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " - f"than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" ) @@ -85,9 +84,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100755 new mode 100644 index 27414d717..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -1,818 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./lstm_transducer_stateless/decode.py \ - --epoch 30 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./lstm_transducer_stateless/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./lstm_transducer_stateless/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - feature = torch.nn.functional.pad( - feature, - (0, 0, 0, num_tail_padded_frames), - mode="constant", - value=LOG_EPS, - ) - feature_lens += num_tail_padded_frames - - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100755 new mode 100644 index 13dac6009..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.trace() - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 \ - --jit-trace 1 - -It will generate 3 files: `encoder_jit_trace.pt`, -`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(2) Export `model.state_dict()` - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 35 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `lstm_transducer_stateless/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./lstm_transducer_stateless/decode.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 - # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit-trace", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace. - It will generate 3 files: - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - - Check ./jit_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def export_encoder_model_jit_trace( - encoder_model: nn.Module, - encoder_filename: str, -) -> None: - """Export the given encoder model with torch.jit.trace() - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported model. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - states = encoder_model.get_init_states() - - traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) - traced_model.save(encoder_filename) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_jit_trace( - decoder_model: nn.Module, - decoder_filename: str, -) -> None: - """Export the given decoder model with torch.jit.trace() - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The input decoder model - decoder_filename: - The filename to save the exported model. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) - traced_model.save(decoder_filename) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_jit_trace( - joiner_model: nn.Module, - joiner_filename: str, -) -> None: - """Export the given joiner model with torch.jit.trace() - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - - Args: - joiner_model: - The input joiner model - joiner_filename: - The filename to save the exported model. - - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) - traced_model.save(joiner_filename) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit_trace is True: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - else: - logging.info("Not using torchscript") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100755 new mode 100644 index 594c33e4f..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -1,322 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, either exported by `torch.jit.trace()` -or by `torch.jit.script()`, and uses them to decode waves. -You can use the following command to get the exported models: - -./lstm_transducer_stateless/export.py \ - --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --jit-trace 1 - -Usage of this script: - -./lstm_transducer_stateless/jit_pretrained.py \ - --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - states = encoder.get_init_states(batch_size=features.size(0), device=device) - - encoder_out, encoder_out_lens, _ = encoder( - x=features, - x_lens=feature_lengths, - states=states, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -1,871 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -from typing import List, Optional, Tuple - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv2d, - ScaledLinear, - ScaledLSTM, -) -from torch import nn - -LOG_EPSILON = math.log(1e-10) - - -def unstack_states( - states: Tuple[torch.Tensor, torch.Tensor] -) -> List[Tuple[torch.Tensor, torch.Tensor]]: - """ - Unstack the lstm states corresponding to a batch of utterances into a list - of states, where the i-th entry is the state from the i-th utterance. - - Args: - states: - A tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. - - Returns: - A list of states. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - """ - hidden_states, cell_states = states - - list_hidden_states = hidden_states.unbind(dim=1) - list_cell_states = cell_states.unbind(dim=1) - - ans = [ - (h.unsqueeze(1), c.unsqueeze(1)) - for (h, c) in zip(list_hidden_states, list_cell_states) - ] - return ans - - -def stack_states( - states_list: List[Tuple[torch.Tensor, torch.Tensor]] -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Stack list of lstm states corresponding to separate utterances into a single - lstm state so that it can be used as an input for lstm when those utterances - are formed into a batch. - - Args: - state_list: - Each element in state_list corresponds to the lstm state for a single - utterance. - ``states[i]`` is a tuple of 2 elememts of i-th utterance. - ``states[i][0]`` is the lstm hidden states of i-th utterance. - ``states[i][1]`` is the lstm cell states of i-th utterance. - - - Returns: - A new state corresponding to a batch of utterances. - It is a tuple of 2 elements. - ``states[0]`` is the lstm hidden states, of a batch of utterance. - ``states[1]`` is the lstm cell states, of a batch of utterances. - """ - hidden_states = torch.cat([s[0] for s in states_list], dim=1) - cell_states = torch.cat([s[1] for s in states_list], dim=1) - ans = (hidden_states, cell_states) - return ans - - -class RNN(EncoderInterface): - """ - Args: - num_features (int): - Number of input features. - subsampling_factor (int): - Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa - d_model (int): - Output dimension (default=512). - dim_feedforward (int): - Feedforward dimension (default=2048). - rnn_hidden_size (int): - Hidden dimension for lstm layers (default=1024). - num_encoder_layers (int): - Number of encoder layers (default=12). - dropout (float): - Dropout rate (default=0.1). - layer_dropout (float): - Dropout value for model-level warmup (default=0.075). - aux_layer_period (int): - Period of auxiliary layers used for random combiner during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - is_pnnx: - True to make this class exportable via PNNX. - """ - - def __init__( - self, - num_features: int, - subsampling_factor: int = 4, - d_model: int = 512, - dim_feedforward: int = 2048, - rnn_hidden_size: int = 1024, - num_encoder_layers: int = 12, - dropout: float = 0.1, - layer_dropout: float = 0.075, - aux_layer_period: int = 0, - is_pnnx: bool = False, - ) -> None: - super(RNN, self).__init__() - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling( - num_features, - d_model, - is_pnnx=is_pnnx, - ) - - self.is_pnnx = is_pnnx - - self.num_encoder_layers = num_encoder_layers - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - encoder_layer = RNNEncoderLayer( - d_model=d_model, - dim_feedforward=dim_feedforward, - rnn_hidden_size=rnn_hidden_size, - dropout=dropout, - layer_dropout=layer_dropout, - ) - self.encoder = RNNEncoder( - encoder_layer, - num_encoder_layers, - aux_layers=list( - range( - num_encoder_layers // 3, - num_encoder_layers - 1, - aux_layer_period, - ) - ) - if aux_layer_period > 0 - else None, - ) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C), where N is the batch size, - T is the sequence length, C is the feature dimension. - x_lens: - A tensor of shape (N,), containing the number of frames in `x` - before padding. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A tuple of 3 tensors: - - embeddings: its shape is (N, T', d_model), where T' is the output - sequence lengths. - - lengths: a tensor of shape (batch_size,) containing the number of - frames in `embeddings` before padding. - - updated states, whose shape is the same as the input states. - """ - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - if not self.is_pnnx: - lengths = (((x_lens - 3) >> 1) - 1) >> 1 - else: - lengths1 = torch.floor((x_lens - 3) / 2) - lengths = torch.floor((lengths1 - 1) / 2) - lengths = lengths.to(x_lens) - - if not torch.jit.is_tracing(): - assert x.size(0) == lengths.max().item() - - if states is None: - x = self.encoder(x, warmup=warmup)[0] - # torch.jit.trace requires returned types to be the same as annotated # noqa - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_encoder_layers, - x.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_encoder_layers, - x.size(1), - self.rnn_hidden_size, - ) - x, new_states = self.encoder(x, states) - - x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, batch_size: int = 1, device: torch.device = torch.device("cpu") - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Get model initial states.""" - # for rnn hidden states - hidden_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.d_model), device=device - ) - cell_states = torch.zeros( - (self.num_encoder_layers, batch_size, self.rnn_hidden_size), - device=device, - ) - return (hidden_states, cell_states) - - -class RNNEncoderLayer(nn.Module): - """ - RNNEncoderLayer is made up of lstm and feedforward networks. - - Args: - d_model: - The number of expected features in the input (required). - dim_feedforward: - The dimension of feedforward network model (default=2048). - rnn_hidden_size: - The hidden dimension of rnn layer. - dropout: - The dropout value (default=0.1). - layer_dropout: - The dropout value for model-level warmup (default=0.075). - """ - - def __init__( - self, - d_model: int, - dim_feedforward: int, - rnn_hidden_size: int, - dropout: float = 0.1, - layer_dropout: float = 0.075, - ) -> None: - super(RNNEncoderLayer, self).__init__() - self.layer_dropout = layer_dropout - self.d_model = d_model - self.rnn_hidden_size = rnn_hidden_size - - assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) - self.lstm = ScaledLSTM( - input_size=d_model, - hidden_size=rnn_hidden_size, - proj_size=d_model if rnn_hidden_size > d_model else 0, - num_layers=1, - dropout=0.0, - ) - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (1, N, d_model); - states[1] is the cell states of all layers, - with shape of (1, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - src_orig = src - - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else 0.1 - ) - else: - alpha = 1.0 - - # lstm module - if states is None: - src_lstm = self.lstm(src)[0] - # torch.jit.trace requires returned types be the same as annotated - new_states = (torch.empty(0), torch.empty(0)) - else: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == (1, src.size(1), self.d_model) - # for cell state - assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) - src_lstm, new_states = self.lstm(src, states) - src = self.dropout(src_lstm) + src - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src, new_states - - -class RNNEncoder(nn.Module): - """ - RNNEncoder is a stack of N encoder layers. - - Args: - encoder_layer: - An instance of the RNNEncoderLayer() class (required). - num_layers: - The number of sub-encoder-layers in the encoder (required). - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - aux_layers: Optional[List[int]] = None, - ) -> None: - super(RNNEncoder, self).__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - self.d_model = encoder_layer.d_model - self.rnn_hidden_size = encoder_layer.rnn_hidden_size - - self.aux_layers: List[int] = [] - self.combiner: Optional[nn.Module] = None - if aux_layers is not None: - assert len(set(aux_layers)) == len(aux_layers) - assert num_layers - 1 not in aux_layers - self.aux_layers = aux_layers + [num_layers - 1] - self.combiner = RandomCombine( - num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0, - ) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer in turn. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (num_layers, N, d_model); - states[1] is the cell states of all layers, - with shape of (num_layers, N, rnn_hidden_size). - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - if states is not None: - assert not self.training - assert len(states) == 2 - if not torch.jit.is_tracing(): - # for hidden state - assert states[0].shape == ( - self.num_layers, - src.size(1), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_layers, - src.size(1), - self.rnn_hidden_size, - ) - - output = src - - outputs = [] - - new_hidden_states = [] - new_cell_states = [] - - for i, mod in enumerate(self.layers): - if states is None: - output = mod(output, warmup=warmup)[0] - else: - layer_state = ( - states[0][i : i + 1, :, :], # h: (1, N, d_model) - states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) - ) - output, (h, c) = mod(output, layer_state) - new_hidden_states.append(h) - new_cell_states.append(c) - - if self.combiner is not None and i in self.aux_layers: - outputs.append(output) - - if self.combiner is not None: - output = self.combiner(outputs) - - if states is None: - new_states = (torch.empty(0), torch.empty(0)) - else: - new_states = ( - torch.cat(new_hidden_states, dim=0), - torch.cat(new_cell_states, dim=0), - ) - - return output, new_states - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-3)//2-1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - is_pnnx: bool = False, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >= 9, in_channels >= 9. - out_channels - Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - is_pnnx: - True if we are converting the model to PNNX format. - False otherwise. - """ - assert in_channels >= 9 - super().__init__() - - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=0, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear( - layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels - ) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55 - ) - - # ncnn supports only batch size == 1 - self.is_pnnx = is_pnnx - self.conv_out_dim = self.out.weight.shape[1] - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-3)//2-1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - - if torch.jit.is_tracing() and self.is_pnnx: - x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) - x = self.out(x) - else: - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - - # Now x is of shape (N, ((T-3)//2-1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x - - -class RandomCombine(nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0, - ) -> None: - """ - Args: - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: - The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: - The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: - A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, or combinations of layers, to use, - is conceptually as follows:: - - With probability `pure_prob`:: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else:: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super().__init__() - assert 0 <= pure_prob <= 1, pure_prob - assert 0 < final_weight < 1, final_weight - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev = stddev - - self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) - .log() - .item() - ) - - def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not self.training or torch.jit.is_scripting(): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights( - inputs[0].dtype, inputs[0].device, num_frames - ) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - - ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) - - # The following if causes errors for torch script in torch 1.6.0 - # if __name__ == "__main__": - # # for testing only... - # print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - def _get_random_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ) -> torch.Tensor: - """Return a tensor of random weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired - Returns: - A tensor of shape (num_frames, self.num_inputs), such that - `ans.sum(dim=1)` is all ones. - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where( - torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m - ) - - def _get_random_pure_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A one-hot tensor of shape `(num_frames, self.num_inputs)`, with - exactly one weight equal to 1.0 on each frame. - """ - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) - - indexes = torch.where( - torch.rand(num_frames, device=device) < final_prob, final, nonfinal - ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) - return ans - - def _get_random_mixed_weights( - self, dtype: torch.dtype, device: torch.device, num_frames: int - ): - """Return a tensor of random one-hot weights, of shape - `(num_frames, self.num_inputs)`, - Args: - dtype: - The data-type desired for the answer, e.g. float, double. - device: - The device needed for the answer. - num_frames: - The number of sets of weights desired. - Returns: - A tensor of shape (num_frames, self.num_inputs), which elements - in [0..1] that sum to one over the second axis, i.e. - `ans.sum(dim=1)` is all ones. - """ - logprobs = ( - torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) - * self.stddev # noqa - ) - logprobs[:, -1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa - ) - num_inputs = 3 - num_channels = 50 - m = RandomCombine( - num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev, - ) - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -def _test_random_combine_main(): - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = RNN(num_features=feature_dim, d_model=128) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - feature_dim = 80 - m = RNN( - num_features=feature_dim, - d_model=512, - rnn_hidden_size=1024, - dim_feedforward=2048, - num_encoder_layers=12, - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = m( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) - num_param = sum([p.numel() for p in m.parameters()]) - print(f"Number of model parameters: {num_param}") - - _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -1,210 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100755 new mode 100644 index 2a6e2adc6..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -1,352 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) greedy search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./lstm_transducer_stateless/pretrained.py \ - --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. - -Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by -./lstm_transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..e69de29bb 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -1,148 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class Stream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - LOG_EPS: float = math.log(1e-10), - ) -> None: - """ - Args: - params: - It's the return value of :func:`get_params`. - cut_id: - The cut id of the current stream. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - device: - The device to run this stream. - LOG_EPS: - A float value used for padding. - """ - self.LOG_EPS = LOG_EPS - self.cut_id = cut_id - - # Containing attention caches and convolution caches - self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # It uses different attributes for different decoding methods. - self.context_size = params.context_size - self.decoding_method = params.decoding_method - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) - ) - self.hyp: Optional[List[int]] = None - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - self.ground_truth: str = "" - - self.feature: Optional[torch.Tensor] = None - # Make sure all feature frames can be used. - # We aim to obtain 1 frame after subsampling. - self.chunk_length = params.subsampling_factor - self.pad_length = 5 - self.num_frames = 0 - self.num_processed_frames = 0 - - # After all feature frames are processed, we set this flag to True - self._done = False - - def set_feature(self, feature: torch.Tensor) -> None: - assert feature.dim() == 2, feature.dim() - # tail padding here to alleviate the tail deletion problem - num_tail_padded_frames = 35 - self.num_frames = feature.size(0) + num_tail_padded_frames - self.feature = torch.nn.functional.pad( - feature, - (0, 0, 0, self.pad_length + num_tail_padded_frames), - mode="constant", - value=self.LOG_EPS, - ) - - def get_feature_chunk(self) -> torch.Tensor: - """Get a chunk of feature frames. - - Returns: - A tensor of shape (ret_length, feature_dim). - """ - update_length = min( - self.num_frames - self.num_processed_frames, self.chunk_length - ) - ret_length = update_length + self.pad_length - - ret_feature = self.feature[ - self.num_processed_frames : self.num_processed_frames + ret_length - ] - # Cut off used frames. - # self.feature = self.feature[update_length:] - - self.num_processed_frames += update_length - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_feature - - @property - def id(self) -> str: - return self.cut_id - - @property - def done(self) -> bool: - """Return True if all feature frames are processed.""" - return self._done - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.decoding_method == "greedy_search": - return self.hyp[self.context_size :] - elif self.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.context_size :] - else: - assert self.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100755 new mode 100644 index d6376bdc0..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -1,968 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method greedy_search \ - --use-averaged-model True - -(2) modified beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method modified_beam_search \ - --use-averaged-model True \ - --beam-size 4 - -(3) fast beam search -./lstm_transducer_stateless/streaming_decode.py \ - --epoch 35 \ - --avg 10 \ - --exp-dir lstm_transducer_stateless/exp \ - --num-decode-streams 2000 \ - --num-encoder-layers 12 \ - --rnn-hidden-size 1024 \ - --decoding-method fast_beam_search \ - --use-averaged-model True \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" -import argparse -import logging -import warnings -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from lstm import LOG_EPSILON, stack_states, unstack_states -from stream import Stream -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import one_best_decoding -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_emformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--sampling-rate", - type=float, - default=16000, - help="Sample rate of the audio", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded in parallel", - ) - - add_model_arguments(parser) - - return parser - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - T = encoder_out.size(1) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (batch_size, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[Stream], - beam: int = 4, -): - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - beam: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - streams: List[Stream], - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using modified beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - streams: - A list of stream objects. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - assert B == len(streams) - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyps[i] - - -def decode_one_chunk( - model: nn.Module, - streams: List[Stream], - params: AttributeDict, - decoding_graph: Optional[k2.Fsa] = None, -) -> List[int]: - """ - Args: - model: - The Transducer model. - streams: - A list of Stream objects. - params: - It is returned by :func:`get_params`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. - - Returns: - A list of indexes indicating the finished streams. - """ - device = next(model.parameters()).device - - feature_list = [] - feature_len_list = [] - state_list = [] - num_processed_frames_list = [] - - for stream in streams: - # We should first get `stream.num_processed_frames` - # before calling `stream.get_feature_chunk()` - # since `stream.num_processed_frames` would be updated - num_processed_frames_list.append(stream.num_processed_frames) - feature = stream.get_feature_chunk() - feature_len = feature.size(0) - feature_list.append(feature) - feature_len_list.append(feature_len) - state_list.append(stream.states) - - features = pad_sequence( - feature_list, batch_first=True, padding_value=LOG_EPSILON - ).to(device) - feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) - - # Make sure it has at least 1 frame after subsampling - tail_length = params.subsampling_factor + 5 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPSILON, - ) - - # Stack states of all streams - states = stack_states(state_list) - - encoder_out, encoder_out_lens, states = model.encoder( - x=features, - x_lens=feature_lens, - states=states, - ) - - if params.decoding_method == "greedy_search": - greedy_search( - model=model, - streams=streams, - encoder_out=encoder_out, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=streams, - encoder_out=encoder_out, - beam=params.beam_size, - ) - elif params.decoding_method == "fast_beam_search": - # feature_len is needed to get partial results. - # The rnnt_decoding_stream for fast_beam_search. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens - ) - fast_beam_search_one_best( - model=model, - streams=streams, - encoder_out=encoder_out, - processed_lens=processed_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - - # Update cached states of each stream - state_list = unstack_states(states) - for i, s in enumerate(state_list): - streams[i].states = s - - finished_streams = [i for i, stream in enumerate(streams) if stream.done] - return finished_streams - - -def create_streaming_feature_extractor() -> Fbank: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return Fbank(opts) - - -def decode_dataset( - cuts: CutSet, - model: nn.Module, - params: AttributeDict, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -): - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The Transducer model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or LG, Used - only when --decoding_method is fast_beam_search. - - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = next(model.parameters()).device - - log_interval = 300 - - fbank = create_streaming_feature_extractor() - - decode_results = [] - streams = [] - for num, cut in enumerate(cuts): - # Each utterance has a Stream. - stream = Stream( - params=params, - cut_id=cut.id, - decoding_graph=decoding_graph, - device=device, - LOG_EPS=LOG_EPSILON, - ) - - stream.states = model.encoder.get_init_states(device=device) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - feature = fbank(samples) - stream.set_feature(feature) - stream.ground_truth = cut.supervisions[0].text - - streams.append(stream) - - while len(streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - while len(streams) > 0: - finished_streams = decode_one_chunk( - model=model, - streams=streams, - params=params, - decoding_graph=decoding_graph, - ) - - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - streams[i].id, - streams[i].ground_truth.split(), - sp.decode(streams[i].decoding_result()).split(), - ) - ) - del streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - else: - key = f"beam_size_{params.beam_size}" - - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - store_transcripts(filename=recog_path, texts=sorted(results)) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-streaming-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - params.device = device - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - model=model, - params=params, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - torch.manual_seed(20220810) - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100755 new mode 100644 index d30fc260a..e69de29bb --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -1,1157 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./lstm_transducer_stateless/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir lstm_transducer_stateless/exp \ - --full-libri 1 \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from lstm import RNN -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - display_and_save_batch, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=int, - default=12, - help="Number of RNN encoder layers..", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=512, - help="Encoder output dimesion.", - ) - - parser.add_argument( - "--rnn-hidden-size", - type=int, - default=1024, - help="Hidden dim for LSTM layers.", - ) - - parser.add_argument( - "--aux-layer-period", - type=int, - default=0, - help="""Peroid of auxiliary layers used for randomly combined during training. - If set to 0, will not use the random combiner (Default). - You can set a positive integer to use the random combiner, e.g., 3. - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=35, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="lstm_transducer_stateless/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="""The initial learning rate. This value should not need to be - changed.""", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--delay-penalty", - type=float, - default=0.0, - help="""A constant value used to penalize symbol delay, - to encourage streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details.""", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = RNN( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - rnn_hidden_size=params.rnn_hidden_size, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - aux_layer_period=params.aux_layer_period, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - reduction="none", - delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_loss: {simple_loss}\n" - f"pruned_loss: {pruned_loss}" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss = simple_loss.sum() - pruned_loss = pruned_loss.sum() - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 30: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - if params.full_libri is False: - params.valid_interval = 800 - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - # # overwrite it - # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] - # print(scheduler.base_lrs) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./lstm.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 3) // 2 - 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librispeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=warmup, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..f7e1b5a54 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,20 +185,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -295,8 +299,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -474,9 +477,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +536,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -700,9 +698,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -735,8 +731,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -789,9 +784,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -826,13 +819,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -860,13 +852,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -895,7 +886,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -961,9 +952,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 190673638..0ad00cda3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,20 +146,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -225,8 +229,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -342,9 +345,7 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) + c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) warmup = 1.0 torch.onnx.export( @@ -445,13 +446,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -550,13 +547,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -585,13 +581,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -620,7 +615,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -694,9 +689,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..5a8efd718 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,10 +86,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -124,10 +126,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -315,9 +316,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..3b471fa85 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,9 +156,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -200,10 +198,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -286,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..7d931a286 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,9 +92,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -119,10 +121,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -169,8 +173,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,10 +204,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,15 +269,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +345,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..baff15ea6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,9 +144,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -188,10 +186,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -229,9 +226,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -310,9 +305,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,9 +321,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -343,9 +334,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..b31fefa0a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,10 +109,12 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -147,10 +149,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -199,9 +200,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +257,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +300,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +319,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +358,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +461,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..08a895a75 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,42 +235,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +688,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +701,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +714,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +945,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +991,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1155,9 +1139,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..a8d5605fb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,20 +182,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -290,8 +294,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -386,9 +389,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +442,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -522,9 +520,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -599,9 +595,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +604,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -650,8 +642,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +669,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +713,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,13 +745,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -787,13 +773,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -821,7 +806,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -848,9 +833,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..51238f768 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,20 +122,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -172,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -281,13 +284,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -310,13 +312,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -344,7 +345,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -380,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..180ba8c72 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,10 +85,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -123,10 +125,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -314,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..6e51b85e4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,9 +661,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +758,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..4f8049245 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..4e9063a40 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,8 +101,9 @@ def get_parser(): "--epoch", type=int, default=40, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( @@ -119,20 +120,24 @@ def get_parser(): "--avg", type=int, default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -199,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +363,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +380,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +539,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +581,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +593,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -773,8 +768,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -816,9 +810,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,13 +844,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -881,13 +872,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -915,7 +905,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..a1d19fb73 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -232,42 +230,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -606,11 +607,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +647,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +660,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +840,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +857,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +992,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..fd2a5354a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,17 +74,18 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -96,75 +97,91 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), ) group.add_argument( "--duration-factor", type=float, default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), ) group.add_argument( "--gap", type=float, default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -178,18 +195,22 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) def train_dataloaders( @@ -208,20 +229,16 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - f"Using cut concatenation with duration factor " + "Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -236,9 +253,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +296,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +353,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +400,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..3b6d0549d 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. - encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, + knowledge_N, + knowledge_D, + knowledge_K, + d_model, + knowledge_base, + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +313,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +367,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +382,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +657,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,33 +728,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -795,9 +783,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +827,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +852,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..65da19f27 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -98,16 +94,19 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -186,8 +185,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +243,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +258,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -309,11 +302,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -385,9 +374,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -419,8 +406,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..2ca76a30c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,38 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = ( + "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," + " scale={scale}" + ) if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..1af05d9c8 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -174,9 +176,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..8cc930927 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. - - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,21 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device), + torch.randn(B, T, E, device=device), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +337,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +350,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. """ - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..a60d15c3b 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,42 +177,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -554,23 +555,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +727,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +827,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..1df1650f3 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,20 +123,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -204,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +275,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +290,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -338,11 +336,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -450,8 +442,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -494,9 +485,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,13 +517,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,13 +545,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -591,7 +578,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..81afb523d 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -170,13 +173,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -199,13 +201,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -233,7 +234,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -273,9 +274,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..6b30d3be8 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,42 +209,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -566,11 +569,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +598,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -908,8 +903,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..03ad45f49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,11 +128,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -171,9 +167,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -269,8 +267,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -383,9 +380,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +445,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +576,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -619,8 +609,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -678,9 +667,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +705,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +743,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..72593173c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,9 +92,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..64708e524 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,17 +64,20 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help=( + "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." + ), ) parser.add_argument( "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. " + ), ) parser.add_argument( @@ -105,8 +108,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -192,9 +194,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..a42b63b9c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -221,10 +224,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -292,9 +294,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +381,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..a50b4d4f0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -269,9 +266,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +413,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -460,8 +449,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -533,8 +521,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..dd0331a60 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,42 +203,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -562,9 +565,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -584,9 +585,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +776,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -897,8 +894,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -956,9 +952,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..5e9428b60 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +775,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +791,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -990,9 +986,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +998,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1676,9 +1668,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +1794,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +1804,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +1827,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1995,9 +1979,7 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2032,10 +2014,7 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2067,9 +2046,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..34ff0d7e2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +776,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +800,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1114,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,33 +1185,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1264,23 +1243,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1293,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1322,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1461,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..32cd53be3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -177,9 +173,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -275,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -397,9 +394,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +460,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -514,11 +506,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -608,9 +596,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -643,8 +629,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -700,9 +685,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +723,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +761,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..b59928103 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,15 +107,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..90367bd03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -173,8 +170,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +218,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..1954f4724 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,9 +60,7 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..58de6875f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,9 +91,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -118,10 +120,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -168,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,10 +225,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -293,9 +295,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +382,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..f671e97b1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1001,17 +988,18 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median *" + " threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..0139863a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -94,9 +90,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -162,8 +160,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -271,9 +268,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +288,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +344,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -425,9 +416,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -462,8 +451,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -536,8 +524,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..623bdd51a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -234,42 +231,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -634,9 +634,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +647,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +660,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +828,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -963,8 +952,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..5e81aef07 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -44,59 +41,69 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), ) group.add_argument( "--max-duration", type=int, default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", + help=( + "When enabled, the batches will come from buckets of " + "similar duration (saves padding frames)." + ), ) group.add_argument( "--num-buckets", type=int, default=30, - help="The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets).", + help=( + "The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets)." + ), ) group.add_argument( "--shuffle", type=str2bool, default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", + help=( + "When enabled (=default), the examples will be shuffled for each epoch." + ), ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", + help=( + "When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it." + ), ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that " - "collect the batches.", + help="The number of training dataloader workers that collect the batches.", ) group.add_argument( @@ -117,18 +124,22 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", + help=( + "When enabled, select noise from MUSAN and mix it" + "with training dataset. " + ), ) group.add_argument( @@ -142,9 +153,11 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet", + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet" + ), ) def train_dataloaders( @@ -167,9 +180,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +189,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +259,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..66c8e30ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -120,9 +116,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -192,8 +190,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +277,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +307,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -359,21 +351,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -446,9 +428,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +461,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +511,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +544,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..d90497e26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,11 +120,7 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -167,9 +163,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -478,9 +475,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +545,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -646,21 +638,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - f"temperature_{params.temperature}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -690,12 +672,7 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps - } + return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} def decode_dataset( @@ -779,9 +756,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -814,8 +789,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -939,9 +913,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -981,8 +953,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1003,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,9 +1031,7 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..dcf65e937 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,11 +128,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -164,9 +160,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -235,8 +233,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -509,13 +506,9 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -616,8 +609,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -715,9 +707,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..598434f54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..108915389 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,10 +104,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -142,10 +144,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -330,9 +331,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..163d737e3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,9 +203,7 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } + joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -214,16 +212,10 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) + ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } + joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -232,11 +224,7 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() - ) + ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) @torch.no_grad() @@ -288,9 +276,7 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..11597aa49 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,10 +102,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -140,10 +142,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -191,11 +192,7 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, )[0] blank_id = 0 # hard-code to 0 @@ -382,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..849d6cf4e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,10 +234,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -302,9 +304,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +391,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e6022b57..85d87f8f2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -251,12 +249,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..41a712498 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -95,9 +91,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -163,8 +161,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -272,9 +269,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +289,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +345,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +417,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -461,8 +450,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -535,8 +523,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..6724343dd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -238,42 +234,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -672,9 +671,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +684,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +697,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +909,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +955,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1109,9 +1096,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..69cfcd298 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,20 +197,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -306,8 +310,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -427,9 +430,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +486,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,9 +564,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -643,9 +639,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +648,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -694,8 +686,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -722,9 +713,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +762,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,13 +799,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -841,13 +827,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -875,7 +860,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -902,9 +887,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..bd5801a78 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -183,13 +186,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,13 +214,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -246,7 +247,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -282,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..a28e52c78 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..76785a845 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,42 +237,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -621,11 +622,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +662,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -680,14 +675,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +688,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +867,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1013,8 +999,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..8499651d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1118,9 +1103,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,33 +1174,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1253,23 +1230,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1279,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1301,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1442,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1623,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1720,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( @@ -1804,7 +1757,8 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," + f" stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..f462cc42f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,20 +179,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -303,8 +307,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -477,9 +480,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +546,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -696,9 +694,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -731,8 +727,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -787,9 +782,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -828,13 +821,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -857,13 +849,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -891,7 +882,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -937,9 +928,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..a739c17bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,20 +89,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -133,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,13 +184,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -210,13 +212,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -244,7 +245,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..e2da0da4c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,9 +89,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -116,10 +118,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -166,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +201,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +266,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +342,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..59a0e8fa2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,20 +96,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -175,8 +179,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +287,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +307,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +363,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +435,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -473,8 +468,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -547,13 +541,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -576,13 +569,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -610,7 +602,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..75696d61b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -272,42 +269,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -645,11 +645,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +686,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +699,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +712,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +895,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,7 +1008,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1045,8 +1030,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..40ad61fd4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. @@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,33 +718,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -790,9 +773,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +781,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +817,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +842,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..600aa9b39 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,20 +120,24 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -208,8 +212,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +270,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -334,11 +332,7 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps + f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -446,8 +438,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -490,9 +481,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,13 +513,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -553,13 +541,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -587,7 +574,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..17f8614dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -87,9 +83,11 @@ def get_parser(): "--avg", type=int, default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( @@ -120,8 +118,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +157,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +205,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..be54ff0ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,42 +201,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -569,9 +570,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +603,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +650,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +663,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +678,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +859,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1007,8 +991,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..40f97f662 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,9 +68,7 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" - ) + self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -208,9 +206,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,10 +223,11 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + split_cmd = ( + "lhotse split" + f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" ) - split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -240,16 +237,13 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -269,8 +263,7 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +323,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..fa8144935 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,20 +164,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -272,8 +276,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -393,9 +396,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +455,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -588,9 +586,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -623,8 +619,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -679,9 +674,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,13 +711,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -747,13 +739,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -781,7 +772,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -808,9 +799,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..5f90e6375 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -91,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..43ac658e5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -215,13 +218,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -244,13 +246,12 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -278,7 +279,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -316,9 +317,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index e2405d5ef..c94a34d58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..3ddac2cf2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. +import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..460ac2c3e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,11 +37,10 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager def batched_params(self, param_group): """ @@ -73,7 +72,9 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -82,7 +83,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] + batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -94,76 +95,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): defaults = dict( lr=lr, @@ -183,7 +185,6 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -206,7 +207,9 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -225,13 +228,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +246,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,36 +254,30 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, pairs: List[Tuple[Tensor, dict]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -314,57 +307,67 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}," + f" model_norm_threshold={model_norm_threshold}" + ) return ans - - def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +394,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -411,24 +415,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +444,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +473,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,8 +496,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +508,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +526,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +534,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +548,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +669,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -745,13 +736,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +758,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +798,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +825,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +836,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,79 +874,93 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" + f" {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..8b4d88871 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..4040065e1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. """ + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). """ + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,9 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,9 +527,7 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x, + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,21 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +670,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +680,29 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}," + f" num_channels={x_orig.shape[-1]}," + f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +737,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +757,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,11 +779,14 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -768,7 +795,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +810,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +831,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +854,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +862,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +877,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +892,10 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}," + f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -869,17 +903,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +922,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +982,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +991,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1006,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -990,7 +1024,6 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1035,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1062,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1078,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1105,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1137,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1150,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..46e775285 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,11 +26,7 @@ from typing import List import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten class NonScaledNorm(nn.Module): @@ -75,12 +71,10 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) + raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") + raise AttributeError("`" + item + "` is not an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 3f27736b3..7f9526104 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,9 +84,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -124,7 +122,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -139,9 +140,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -269,42 +272,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -646,11 +652,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +699,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,9 +870,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -890,11 +888,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -905,9 +899,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -915,10 +907,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -930,7 +919,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1009,9 +999,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1029,7 +1017,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1054,8 +1042,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1229,7 +1216,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 023dec97d..fcd9858cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,32 +16,35 @@ # limitations under the License. import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import make_pad_mask class Zipformer(EncoderInterface): @@ -89,7 +92,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +100,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +126,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +142,11 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,27 +170,33 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has" + f" downsampling_factor={z[i]}, we will combine the outputs" + f" of layers {j} and {i-1}, with" + f" downsampling_factors={z[j]} and {z[i-1]}." + ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -206,46 +216,56 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -265,13 +285,19 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + assert x.size(0) == lengths.max().item(), ( + x.shape, + lengths, + lengths.max(), + ) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -280,9 +306,11 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +340,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +359,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,14 +384,18 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01, + ) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -382,8 +410,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -398,8 +427,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -508,13 +538,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +559,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +568,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +607,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +634,13 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}," + f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," + f" num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,7 +671,6 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -670,28 +701,31 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -718,42 +752,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -767,16 +802,14 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +828,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +846,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +854,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +864,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -853,10 +882,14 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +902,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -888,9 +918,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -905,9 +933,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +981,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1017,46 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query, key + + pos_dim * num_heads # value + ) # positional encoding query - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, + in_proj_dim, + bias=True, + initial_scale=self.head_dim**-0.25, + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1031,14 +1068,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1137,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1156,26 +1194,24 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + assert head_dim * num_heads == attention_dim, ( + f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," + f" {attention_dim}" + ) # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,33 +1231,25 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor" + " instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,7 +1258,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1266,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,13 +1280,16 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1275,10 +1302,9 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1320,20 +1346,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [ + bsz * num_heads, + seq_len, + head_dim // 2, + ] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1372,11 +1398,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1409,50 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}," + f" covar={attn_covar}, in_proj_covar={in_proj_covar}," + f" out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1430,7 +1463,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1477,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1481,9 +1509,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1539,10 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1556,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1575,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1626,8 +1658,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1667,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1702,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1746,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1782,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. ] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1825,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1833,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1851,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1837,19 +1872,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 9d7335e77..822f8e44b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,20 +165,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -273,8 +277,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -394,9 +397,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -455,10 +456,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -589,9 +587,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -624,8 +620,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -680,9 +675,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -719,13 +712,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -753,13 +745,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -788,7 +779,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -816,9 +807,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 49f469e29..43eb0c1bc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,20 +129,24 @@ def get_parser(): "--avg", type=int, default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), ) parser.add_argument( @@ -176,8 +180,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -217,13 +220,12 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -252,13 +254,12 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -287,7 +288,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " + "Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -326,9 +327,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index e79a3a3aa..ed920dc03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,10 +69,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) return parser @@ -93,10 +95,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,9 +268,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 497b89136..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,9 +160,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 373a48fc1..716136812 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,9 +100,11 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + help=( + "Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint()." + ), ) parser.add_argument( @@ -127,10 +129,12 @@ def get_parser(): "sound_files", type=str, nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + help=( + "The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz." + ), ) parser.add_argument( @@ -177,8 +181,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +212,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +277,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +353,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 2603bb854..381a86a67 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -132,7 +130,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + help=( + "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" + " separated" + ), ) parser.add_argument( @@ -147,9 +148,11 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", + help=( + "Unmasked dimensions in the encoders, relates to augmentation during" + " training. Must be <= each of encoder_dims. Empirically, less than 256" + " seems to make performance worse." + ), ) parser.add_argument( @@ -214,8 +217,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -285,42 +287,45 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", + help=( + "The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss" + ), ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", + help=( + "The scale to smooth the loss with lm (output of prediction network) part." + ), ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", + help=( + "To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss." + ), ) parser.add_argument( @@ -691,11 +696,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -744,9 +745,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -952,9 +951,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -975,11 +972,7 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -992,12 +985,8 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1011,10 +1000,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1026,7 +1012,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1054,8 +1041,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1152,9 +1138,7 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 - ) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1172,7 +1156,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1207,9 +1191,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1364,7 +1346,8 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is" + f" {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..4f7427c1f 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,33 +1090,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1185,24 +1147,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1190,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..5a8149aad 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -62,32 +63,36 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
         "--chunk-size",
         type=int,
         default=8,
-        help="Frames of right context"
-        "-1 for whole right context, i.e. non-streaming decoding",
+        help=(
+            "Frames of right context"
+            "-1 for whole right context, i.e. non-streaming decoding"
+        ),
     )
 
     parser.add_argument(
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right,only used during decoding",
     )
 
     parser.add_argument(
@@ -139,8 +144,7 @@ def get_parser():
         "--avg-models",
         type=str,
         default=None,
-        help="Manually select models to average, seperated by comma;"
-        "e.g. 60,62,63,72",
+        help="Manually select models to average, seperated by comma;e.g. 60,62,63,72",
     )
 
     return parser
@@ -248,13 +252,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,9 +337,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -357,15 +355,18 @@ def save_results(
     test_set_wers = dict()
     if params.avg_models is not None:
         avg_models = params.avg_models.replace(",", "_")
-        result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-avg-{avg_models}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     else:
-        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
-        -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        result_file_prefix = (
+            f"epoch-{params.epoch}-avg-{params.avg}-chunksize        "
+            f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
+        )
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +375,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +384,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +472,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +503,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..553b7d092 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -405,9 +405,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +434,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +547,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -668,9 +662,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..63afd6be2 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -77,17 +77,18 @@ class LibriSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -99,59 +100,74 @@ class LibriSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--drop-last",
@@ -163,17 +179,18 @@ class LibriSpeechAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -187,18 +204,22 @@ class LibriSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -224,20 +245,16 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -252,9 +269,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,9 +313,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +369,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..94ba0a4dc 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +466,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +495,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..722e8f003 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..071ac792b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..f30332cea 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -228,9 +231,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +246,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +317,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -353,8 +350,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..4d9f937f5 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -67,17 +67,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=11,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -238,9 +241,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..7aadfbcd1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -60,9 +60,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -87,10 +89,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -188,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -249,9 +252,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +288,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..674ea10a6 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..61b9de504 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -71,16 +71,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=77,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=55,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -225,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -350,8 +347,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..57bda63fd 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -665,9 +661,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..89992856d 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -54,16 +54,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=34,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -124,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -162,9 +164,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +204,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..d279eae85 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -209,10 +209,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -421,9 +418,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +434,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +475,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +501,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +566,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +608,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,9 +760,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -798,9 +777,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -826,9 +803,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -1092,9 +1067,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,33 +1138,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1228,14 +1195,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,9 +1206,7 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1290,9 +1251,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1304,13 +1263,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1373,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..314f49154 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..7c10b4348 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -68,17 +68,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -109,8 +112,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -244,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..bd7eeff28 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..bcb883fa5 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..86ef9e5b6 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -94,16 +94,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -171,8 +174,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +232,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -450,9 +440,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..d95eeb1f4 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -63,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -104,8 +107,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -176,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..793931e3b 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..68e247f23 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..22b6ab911 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -95,16 +95,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +249,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -298,11 +295,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -375,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -410,8 +401,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -451,9 +441,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..fad9a6977 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..efd257b5d 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -90,9 +90,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -117,10 +119,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -167,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +261,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +334,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..88987d91c 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..83f95d123 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -70,10 +70,12 @@ class SPGISpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -85,67 +87,81 @@ class SPGISpeechAsrDataModule:
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it "
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it "
+                "with training dataset. "
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--max-duration",
             type=int,
             default=100.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the BucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=8,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
@@ -157,10 +173,12 @@ class SPGISpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
     def train_dataloaders(
@@ -176,24 +194,20 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -208,9 +222,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +239,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +292,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +336,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..72a7cd1c1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -117,9 +113,11 @@ def get_parser():
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
@@ -187,8 +185,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +243,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +258,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -312,11 +304,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -389,9 +377,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -424,9 +410,7 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
             params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
         )
@@ -438,32 +422,23 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +471,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +503,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..1f18ae2f3 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -67,17 +63,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -119,8 +118,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -196,9 +194,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..cd835a7b4 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -179,42 +176,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -554,23 +554,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +726,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..2be639b7a 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e.,                         aabb -> a a b"
+            " b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..02bd6e2cc 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -74,10 +74,12 @@ class TAL_CSASRAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
 
         group.add_argument(
@@ -91,66 +93,81 @@ class TAL_CSASRAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
 
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
 
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
 
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
 
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
 
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
 
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
 
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
 
         group.add_argument(
@@ -164,17 +181,18 @@ class TAL_CSASRAsrDataModule:
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -188,18 +206,22 @@ class TAL_CSASRAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -222,24 +244,20 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -254,9 +272,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +316,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +374,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..b2aef7e86 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -124,20 +124,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -208,8 +212,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +271,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +304,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +373,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -396,11 +392,11 @@ def decode_one_batch(
         return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): (hyps, zh_hyps, en_hyps)
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": (
+                hyps,
+                zh_hyps,
+                en_hyps,
+            )
         }
     else:
         return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
@@ -506,9 +502,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -541,8 +535,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -585,9 +578,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,13 +610,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -648,13 +638,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -682,7 +671,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..94a4c7a2e 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -92,20 +92,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -139,8 +143,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,13 +179,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -205,13 +207,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -239,7 +240,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -277,9 +278,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..198242129 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -84,9 +84,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -115,10 +117,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +200,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -263,15 +265,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +365,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..676e8c904 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -238,42 +235,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -600,11 +600,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +630,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +817,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +931,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..9dbcc9d9e 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -25,9 +25,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
index 35dd332e8..b9160b6d4 100755
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ b/egs/tedlium3/ASR/local/prepare_lexicon.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate lexicon_words.txt.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     words = set()
 
     lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         # list the words units and filter the empty item
         words_list = list(filter(None, s.text.split()))
@@ -88,9 +87,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..7ea4e89a4 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -23,11 +23,12 @@ consisting of supervisions_train.json and does the following:
 1. Generate train.text.
 
 """
-import lhotse
 import argparse
 import logging
 from pathlib import Path
 
+import lhotse
+
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -61,9 +62,7 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str):
     texts = []
 
     train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
+    sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
     for s in sups:
         texts.append(s.text)
 
@@ -83,9 +82,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..6bae33e65 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -94,17 +94,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -172,8 +175,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +233,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +248,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -297,11 +294,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -374,9 +367,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -409,8 +400,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..244740932 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -65,17 +65,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=30,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -106,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -179,9 +181,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..00545f107 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -93,9 +93,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -122,10 +124,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -165,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,10 +206,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -271,9 +273,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +298,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +350,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..70c5e290f 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,42 +133,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -556,9 +559,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +679,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..86ac2fea3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -63,10 +63,12 @@ class TedLiumAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -78,75 +80,91 @@ class TedLiumAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -160,18 +178,22 @@ class TedLiumAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
@@ -179,20 +201,16 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -207,9 +225,7 @@ class TedLiumAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -253,9 +269,7 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -300,9 +314,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -358,13 +370,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..12d0e2652 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -81,16 +81,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=13,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -130,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +252,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +275,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +346,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -383,8 +379,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..0b2ae970b 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -69,17 +69,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=20,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=10,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -110,8 +113,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +249,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..912d65497 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -82,9 +82,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -110,10 +112,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -127,8 +131,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -222,10 +225,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -285,9 +287,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +335,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..32c248d7e 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -146,9 +146,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..d11cd3a05 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -20,9 +20,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. About how to get these LM files, you can know it
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..5a59a13ce 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -57,16 +57,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=19,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -336,9 +339,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +401,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +461,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +482,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..da669bc39 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..d957c22e1 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -63,10 +63,12 @@ class TimitAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -78,75 +80,91 @@ class TimitAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -154,15 +172,13 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -178,9 +194,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +228,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +277,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +311,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..319ee5515 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -56,16 +56,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=25,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--method",
@@ -335,9 +338,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +400,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +460,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +480,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..0c72c973b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -46,9 +42,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -58,9 +56,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -103,10 +99,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -144,10 +142,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +212,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +264,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..bd73e520e 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..c228597b8 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -62,8 +62,10 @@ def get_parser():
         "--batch-duration",
         type=float,
         default=600.0,
-        help="The maximum number of audio seconds in a batch."
-        "Determines batch size dynamically.",
+        help=(
+            "The maximum number of audio seconds in a batch."
+            "Determines batch size dynamically."
+        ),
     )
 
     parser.add_argument(
@@ -152,9 +154,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id[""] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..e121d842c 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -50,15 +50,15 @@ def get_parser():
         "-n",
         default=1,
         type=int,
-        help="number of characters to split, i.e., \
-                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+        help=(
+            "number of characters to split, i.e.,                         aabb -> a a b"
+            " b with -n 1 and aa bb with -n 2"
+        ),
     )
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +66,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +106,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +132,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..da7d7e061 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..bd92ac115 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -81,10 +81,12 @@ class WenetSpeechAsrDataModule:
     def add_arguments(cls, parser: argparse.ArgumentParser):
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--manifest-dir",
@@ -96,75 +98,91 @@ class WenetSpeechAsrDataModule:
             "--max-duration",
             type=int,
             default=200.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=300,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
         group.add_argument(
@@ -178,18 +196,22 @@ class WenetSpeechAsrDataModule:
             "--spec-aug-time-warp-factor",
             type=int,
             default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
+            help=(
+                "Used only when --enable-spec-aug is True. "
+                "It specifies the factor for time warping in SpecAugment. "
+                "Larger values mean more warping. "
+                "A value less than 1 means to disable time warp."
+            ),
         )
 
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            help=(
+                "When enabled, select noise from MUSAN and mix it"
+                "with training dataset. "
+            ),
         )
 
         group.add_argument(
@@ -212,24 +234,20 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
@@ -244,9 +262,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +305,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +362,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -414,8 +426,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +438,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..6e856248c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -137,25 +133,30 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--batch",
         type=int,
         default=None,
-        help="It specifies the batch checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the batch checkpoint to use for decoding."
+            "Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -252,8 +253,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +328,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +387,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -438,11 +433,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -515,9 +506,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -550,8 +539,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -663,9 +651,7 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
@@ -716,8 +702,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -727,8 +712,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -739,9 +723,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..c742593df 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -126,17 +126,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -205,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -468,13 +470,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +643,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..ed9020c67 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -107,10 +107,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -145,10 +147,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -331,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..f7d962008 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -111,10 +111,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -149,10 +151,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -200,11 +201,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +386,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..26c9c2b8c 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -158,8 +162,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..e020c4c05 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,42 +217,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -590,22 +591,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +756,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +856,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..1023c931a 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -1073,9 +1060,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,33 +1131,25 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
-                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor"
+                " instead."
             )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
@@ -1208,23 +1187,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1236,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1258,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1393,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..3d66f9dc9 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -160,20 +160,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -244,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -342,9 +345,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +361,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -409,11 +407,7 @@ def decode_one_batch(
         return {"greedy_search": hyps}
     elif params.decoding_method == "fast_beam_search":
         return {
-            (
-                f"beam_{params.beam}_"
-                f"max_contexts_{params.max_contexts}_"
-                f"max_states_{params.max_states}"
-            ): hyps
+            f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps
         }
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -484,9 +478,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -519,8 +511,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -589,13 +580,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -618,13 +608,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -652,7 +641,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
@@ -720,8 +709,7 @@ def main():
         )
 
     dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
     ]
     cuts_dev_webdataset = CutSet.from_webdataset(
         dev_shards,
@@ -731,8 +719,7 @@ def main():
     )
 
     test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
     ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
@@ -743,9 +730,7 @@ def main():
 
     test_meeting_shards = [
         str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
+        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
     ]
     cuts_test_meeting_webdataset = CutSet.from_webdataset(
         test_meeting_shards,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index d0a7fd69f..fb53f70ab 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -90,17 +90,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -131,8 +134,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -201,9 +203,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..9834189d8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -80,9 +80,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -107,10 +109,12 @@ def get_parser():
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     parser.add_argument(
@@ -157,8 +161,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +192,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +255,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +280,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +332,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..31a7fe605 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -119,20 +119,24 @@ def get_parser():
         "--avg",
         type=int,
         default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
     )
 
     parser.add_argument(
@@ -201,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +314,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +334,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +388,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +458,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -499,8 +494,7 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -565,13 +559,12 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg:
                 raise ValueError(
@@ -594,13 +587,12 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
                 )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
@@ -628,7 +620,7 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                f"Calculating the averaged model over epoch range from "
+                "Calculating the averaged model over epoch range from "
                 f"{start} (excluded) to {params.epoch}"
             )
             model.to(device)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..40c9665f7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -284,42 +281,45 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
         "--prune-range",
         type=int,
         default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
+        help=(
+            "The prune range for rnnt loss, it means how many symbols(context)"
+            "we are using to compute the loss"
+        ),
     )
 
     parser.add_argument(
         "--lm-scale",
         type=float,
         default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
+        help=(
+            "The scale to smooth the loss with lm (output of prediction network) part."
+        ),
     )
 
     parser.add_argument(
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)part.",
     )
 
     parser.add_argument(
         "--simple-loss-scale",
         type=float,
         default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
+        help=(
+            "To get pruning ranges, we will calculate a simple version"
+            "loss(joiner is just addition), this simple loss also uses for"
+            "training (as a regularization item). We will scale the simple loss"
+            "with this parameter before adding to the final loss."
+        ),
     )
 
     parser.add_argument(
@@ -665,11 +665,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +697,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +830,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +888,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +1001,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1184,9 +1169,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..7234ca929 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -128,9 +128,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..21860d2f5 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -56,10 +56,12 @@ class YesNoAsrDataModule(DataModule):
         super().add_arguments(parser)
         group = parser.add_argument_group(
             title="ASR data related options",
-            description="These options are used for the preparation of "
-            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
-            "effective batch sizes, sampling strategies, applied data "
-            "augmentations, etc.",
+            description=(
+                "These options are used for the preparation of "
+                "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+                "effective batch sizes, sampling strategies, applied data "
+                "augmentations, etc."
+            ),
         )
         group.add_argument(
             "--feature-dir",
@@ -71,75 +73,91 @@ class YesNoAsrDataModule(DataModule):
             "--max-duration",
             type=int,
             default=30.0,
-            help="Maximum pooled recordings duration (seconds) in a "
-            "single batch. You can reduce it if it causes CUDA OOM.",
+            help=(
+                "Maximum pooled recordings duration (seconds) in a "
+                "single batch. You can reduce it if it causes CUDA OOM."
+            ),
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
             default=False,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
+            help=(
+                "When enabled, the batches will come from buckets of "
+                "similar duration (saves padding frames)."
+            ),
         )
         group.add_argument(
             "--num-buckets",
             type=int,
             default=10,
-            help="The number of buckets for the DynamicBucketingSampler"
-            "(you might want to increase it for larger datasets).",
+            help=(
+                "The number of buckets for the DynamicBucketingSampler"
+                "(you might want to increase it for larger datasets)."
+            ),
         )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
             default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
+            help=(
+                "When enabled, utterances (cuts) will be concatenated "
+                "to minimize the amount of padding."
+            ),
         )
         group.add_argument(
             "--duration-factor",
             type=float,
             default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
+            help=(
+                "Determines the maximum duration of a concatenated cut "
+                "relative to the duration of the longest cut in a batch."
+            ),
         )
         group.add_argument(
             "--gap",
             type=float,
             default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
+            help=(
+                "The amount of padding (in seconds) inserted between "
+                "concatenated cuts. This padding is filled with noise when "
+                "noise augmentation is used."
+            ),
         )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
             default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
+            help=(
+                "When enabled, use on-the-fly cut mixing and feature "
+                "extraction. Will drop existing precomputed feature manifests "
+                "if available."
+            ),
         )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
-            help="When enabled (=default), the examples will be "
-            "shuffled for each epoch.",
+            help=(
+                "When enabled (=default), the examples will be shuffled for each epoch."
+            ),
         )
         group.add_argument(
             "--return-cuts",
             type=str2bool,
             default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
+            help=(
+                "When enabled, each batch will have the "
+                "field: batch['supervisions']['cut'] with the cuts that "
+                "were used to construct it."
+            ),
         )
 
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
-            help="The number of training dataloader workers that "
-            "collect the batches.",
+            help="The number of training dataloader workers that collect the batches.",
         )
 
     def train_dataloaders(self) -> DataLoader:
@@ -150,7 +168,7 @@ class YesNoAsrDataModule(DataModule):
         transforms = []
         if self.args.concatenate_cuts:
             logging.info(
-                f"Using cut concatenation with duration factor "
+                "Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             # Cut concatenation should be the first transform in the list,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..41afe0404 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -35,16 +35,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=14,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=2,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -201,9 +204,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +275,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +296,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +314,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..09a8672ae 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -41,9 +41,11 @@ def get_parser():
         "--checkpoint",
         type=str,
         required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
+        help=(
+            "Path to the checkpoint. "
+            "The checkpoint is assumed to be saved by "
+            "icefall.checkpoint.save_checkpoint()."
+        ),
     )
 
     parser.add_argument(
@@ -53,18 +55,18 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
+        help=(
+            "The input sound file(s) to transcribe. "
+            "Supported formats are those supported by torchaudio.load(). "
+            "For example, wav and flac are supported. "
+            "The sample rate has to be 16kHz."
+        ),
     )
 
     return parser
@@ -101,10 +103,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -159,9 +160,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +200,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..de478334e 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -48,16 +48,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=125,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
     parser.add_argument(
         "--exp-dir",
@@ -116,9 +119,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +187,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +302,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..c31db6e4c 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..8aa0a8eeb 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,11 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints]
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +465,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..6cd87bdc0 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -678,9 +670,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -787,13 +777,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -805,9 +791,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,9 +878,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..7b58ffbd4 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulat "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -221,7 +216,8 @@ class TensorDiagnostic(object):
                     # a dimension that has variable size in different nnet
                     # forwards, e.g. a time dimension in an ASR model.
                     stats = torch.cat(
-                        [x.tensor / get_count(x.count) for x in stats_list], dim=0
+                        [x.tensor / get_count(x.count) for x in stats_list],
+                        dim=0,
                     )
 
                 if stats_type == "eigs":
@@ -229,9 +225,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -242,9 +236,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,15 +255,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"
 
                 # OK, "ans" contains the actual stats, e.g.
@@ -277,17 +271,17 @@ class TensorDiagnostic(object):
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
-                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
+                    f"module={self.name},{maybe_class_name} dim={dim}, size={size_str},"
+                    f" {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +339,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..9df1c5bd1 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -29,9 +29,7 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
         os.environ["MASTER_ADDR"] = "localhost"
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..e2ff03f61 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..16ed6e032 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -115,20 +112,12 @@ def _compute_mmi_loss_exact_non_optimized(
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
     # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
-    )
-    den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
-    )
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)
+    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size)
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -168,13 +157,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..9f680f83d 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -137,9 +137,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +153,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..9a275bf28 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -46,16 +46,19 @@ def get_parser():
         "--epoch",
         type=int,
         default=49,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
     parser.add_argument(
         "--avg",
         type=int,
         default=20,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -194,7 +197,7 @@ def main():
 
     logging.info(f"Number of model parameters: {num_param}")
     logging.info(
-        f"Number of model parameters (requires_grad): "
+        "Number of model parameters (requires_grad): "
         f"{num_param_requires_grad} "
         f"({num_param_requires_grad/num_param_requires_grad*100}%)"
     )
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..4bf982503 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..2e878f5c8 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -38,17 +38,20 @@ def get_parser():
         "--epoch",
         type=int,
         default=29,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help=(
+            "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0."
+        ),
     )
 
     parser.add_argument(
         "--avg",
         type=int,
         default=5,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch'. "
+        ),
     )
 
     parser.add_argument(
@@ -159,9 +162,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..9eef88840 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -129,9 +129,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -161,12 +159,12 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to(
+                device
+            )
 
         embedding = self.input_embedding(tokens)
         rnn_out, states = self.rnn(embedding, (h, c))
@@ -181,12 +179,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +188,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..e17b50332 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -446,17 +446,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -471,8 +467,7 @@ def train_one_epoch(
 
             valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
             logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
-                f"ppl: {valid_ppl}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}"
             )
 
             if tb_writer is not None:
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..a3bf1ef4c 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,50 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm",
+    type=str,
+    default=None,
+    help="Path to output arpa file for language models",
+)
+parser.add_argument(
+    "-verbose",
+    type=int,
+    default=0,
+    choices=[0, 1, 2, 3, 4, 5],
+    help="Verbose level",
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = (
+    "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
+)
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +72,9 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(
+            set
+        )  # using a set to count the number of unique contexts
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +84,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +112,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='', eos_symbol=''):
+    def __init__(self, ngram_order, bos_symbol="", eos_symbol=""):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +130,48 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        infile = io.TextIOWrapper(
+            sys.stdin.buffer, encoding=default_encoding
+        )  # byte stream as input
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +181,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +194,11 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        self.d = [
+            0
+        ]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
+        # but perhaps this is not the case for some other scenarios.
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +208,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(
+                max(0.1, n1 * 1.0) / (n1 + 2 * n2)
+            )  # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +227,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +243,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                 else:  # patterns begin with , they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +293,18 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        for (
+                            u
+                        ) in (
+                            a_counts_for_hist.word_to_count.keys()
+                        ):  # Should be careful here: what is Z1
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +318,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +383,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +428,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +453,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..0beb94b2e 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -130,9 +130,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -280,13 +278,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -355,9 +349,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -578,9 +570,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -590,9 +580,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -606,9 +594,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -783,9 +769,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -795,9 +779,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -811,9 +793,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -883,9 +863,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -919,9 +897,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -967,8 +943,8 @@ def concat(
         ans = k2.ragged.cat([ragged, pad], axis=1)
     else:
         raise ValueError(
-            f'Unsupported direction: {direction}. " \
-            "Expect either "left" or "right"'
+            f'Unsupported direction: {direction}. "             "Expect either "left"'
+            ' or "right"'
         )
     return ans
 
@@ -1093,9 +1069,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1118,9 +1092,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1405,9 +1377,7 @@ def parse_hyp_and_timestamp(
         use_word_table = True
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]

From d89766d85dbb023a8fcb47221545d00b1a015c69 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 13:10:55 -0500
Subject: [PATCH 2/3] add git blame ignore revs file

---
 .git-blame-ignore-revs | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .git-blame-ignore-revs

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..c5908fc89
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+d110b04ad389134c82fa314e3aafc7b40043efb0

From 7a8e8e735d21bd9d992e291e6c39c922154168b5 Mon Sep 17 00:00:00 2001
From: Desh Raj 
Date: Wed, 16 Nov 2022 14:43:21 -0500
Subject: [PATCH 3/3] change click version in pre-commit

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e2055801b..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
     hooks:
       - id: black
         args: ["--line-length=88"]
-        additional_dependencies: ['click==8.0.1']
+        additional_dependencies: ['click==8.1.0']
         exclude: icefall\/__init__\.py
 
   - repo: https://github.com/PyCQA/flake8